diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2023-07-21 17:05:30 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-21 17:05:30 +0300 |
commit | 4d76a5f49b9b5382dba5d13d92edb9159536c225 (patch) | |
tree | 7bb4a3231985d1fb254cb5c38b65daba53cdbe4b /ggml-metal.m | |
parent | 0db14fef06836caaa13cc123c0a24dc598bdb9f0 (diff) |
Faster Q3_K implementation on Metal (#2307)
* Faster Q3_K on Metal
* Additional Q3_K speedup on Metal
* Q3_K for QK_K = 64
* Better Q3_K for QK_K = 64
21.6 ms/t -> 21.1 ms/t
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 135bda9..2810fa2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -685,8 +685,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: @@ -743,15 +743,18 @@ void ggml_metal_graph_compute( src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |