diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 78a3b65..bf3f68f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -585,7 +585,7 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *)(dst->op_params))[0]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -850,9 +850,10 @@ void ggml_metal_graph_compute( GGML_ASSERT((src0t == GGML_TYPE_F32)); - const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); if (__builtin_popcount(n_head) != 1) { GGML_ASSERT(false && "only power-of-two n_head implemented"); @@ -890,15 +891,14 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; float freq_base; float freq_scale; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |