aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> 2023-06-08 10:08:23 +0300
committer: GitHub <noreply@github.com> 2023-06-08 10:08:23 +0300
commit: 4161bdc04debb70bf5f275492b4d89fd9330087c (patch)
tree: 9b0c6325e720b101d67ec2415bc0d69e4fd89379 /ggml-metal.m
parent: 0035858273ebe0694926bf4414d279f3e1cd109d (diff)
metal : add Q4_K implementation (#1733)
* Metal implementation for Q4_K

  Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU.

* Optimizing Q4_K on metal

  The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough.

* Optimizing q4_K metal dot some more

  For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0.

* Fix after merge with master

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m 23
1 file changed, 22 insertions, 1 deletion
diff --git a/ggml-metal.m b/ggml-metal.m
index 0953af6..f2a637b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -49,9 +49,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_k);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -133,9 +135,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_k);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -517,7 +521,20 @@ void ggml_metal_graph_compute(
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
- default: GGML_ASSERT(false && "not implemented");
+ case GGML_TYPE_Q4_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ } break;
+ default:
+ {
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
};
@@ -540,6 +557,9 @@ void ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +575,7 @@ void ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
default: GGML_ASSERT(false && "not implemented");
}