From 44f906e8537fcec965e312d621c80556d6aa9bec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 6 Jun 2023 20:16:57 +0300 Subject: metal : add f16 support --- ggml-metal.m | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) (limited to 'ggml-metal.m') diff --git a/ggml-metal.m b/ggml-metal.m index d721ac6..0953af6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -47,10 +47,11 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(rms_norm); - GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(rms_norm); - GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -498,6 +500,14 @@ void ggml_metal_graph_compute( // use custom matrix x vector kernel switch (src0t) { + case GGML_TYPE_F16: + { + GGML_ASSERT(ne02 == ne12); + + nth0 = 64; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } break; case GGML_TYPE_Q4_0: { GGML_ASSERT(ne02 == 1); @@ -507,14 +517,6 @@ void ggml_metal_graph_compute( nth1 = 4; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(ne02 == ne12); - - nth0 = 32; - nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; - } break; default: GGML_ASSERT(false && "not implemented"); }; @@ -551,6 +553,7 @@ void ggml_metal_graph_compute( } switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; default: GGML_ASSERT(false && "not implemented"); } -- cgit v1.2.3