diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-06-06 20:16:57 +0300 |
---|---|---|
committer | Georgi Gerganov <ggerganov@gmail.com> | 2023-06-06 20:21:56 +0300 |
commit | 44f906e8537fcec965e312d621c80556d6aa9bec (patch) | |
tree | b9b705ed45c4541dda384d2b3fdf92391a16e8a8 /ggml-metal.m | |
parent | d5b111f53d14972669eb52055f9df2567663ad8b (diff) |
metal : add f16 support
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index d721ac6..0953af6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -47,10 +47,11 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(rms_norm); - GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(rms_norm); - GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -498,6 +500,14 @@ void ggml_metal_graph_compute( // use custom matrix x vector kernel switch (src0t) { + case GGML_TYPE_F16: + { + GGML_ASSERT(ne02 == ne12); + + nth0 = 64; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } break; case GGML_TYPE_Q4_0: { GGML_ASSERT(ne02 == 1); @@ -507,14 +517,6 @@ void ggml_metal_graph_compute( nth1 = 4; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(ne02 == ne12); - - nth0 = 32; - nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; - } break; default: GGML_ASSERT(false && "not implemented"); }; @@ -551,6 +553,7 @@ void ggml_metal_graph_compute( } switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; default: GGML_ASSERT(false && "not implemented"); } |