From e9b66ee9829039d4ab54550d6222e42a0b31e52a Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sat, 10 Jun 2023 11:28:11 +0300 Subject: metal : add Q4_1 implementation (#1785) 23.3 ms / token, so just ~1% slower than q4_0. Achieves 290 GB/s memory throughput. Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'ggml-metal.m') diff --git a/ggml-metal.m b/ggml-metal.m index 5c9ecd7..167ebd4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -50,12 +50,14 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q2_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); @@ -141,12 +143,14 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q2_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); @@ -545,6 +549,15 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + } break; case GGML_TYPE_Q2_K: { GGML_ASSERT(ne02 == 1); @@ -596,7 +609,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q2_K) { @@ -623,6 +636,7 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; -- cgit v1.2.3