about | summary | refs | log | tree | commit | diff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m | 66
1 files changed, 36 insertions, 30 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index a7e104d..7551231 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -51,21 +51,21 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q2_k);
- GGML_METAL_DECL_KERNEL(get_rows_q3_k);
- GGML_METAL_DECL_KERNEL(get_rows_q4_k);
- GGML_METAL_DECL_KERNEL(get_rows_q5_k);
- GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
exit(1);
}
+#ifdef GGML_QKK_64
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q2_k);
- GGML_METAL_ADD_KERNEL(get_rows_q3_k);
- GGML_METAL_ADD_KERNEL(get_rows_q4_k);
- GGML_METAL_ADD_KERNEL(get_rows_q5_k);
- GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -662,7 +668,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
@@ -671,7 +677,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
@@ -680,7 +686,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
@@ -689,7 +695,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
@@ -698,7 +704,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
{
@@ -750,11 +756,11 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
default: GGML_ASSERT(false && "not implemented");
}