about | summary | refs | log | tree | commit | diff
path: root/ggml-metal.m
diff options
context:
space:
mode:
author: Georgi Gerganov <ggerganov@gmail.com> 2023-06-06 20:16:57 +0300
committer: Georgi Gerganov <ggerganov@gmail.com> 2023-06-06 20:21:56 +0300
commit: 44f906e8537fcec965e312d621c80556d6aa9bec (patch)
tree: b9b705ed45c4541dda384d2b3fdf92391a16e8a8 /ggml-metal.m
parent: d5b111f53d14972669eb52055f9df2567663ad8b (diff)
metal : add f16 support
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m 23
1 files changed, 13 insertions, 10 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index d721ac6..0953af6 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -47,10 +47,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
// use custom matrix x vector kernel
switch (src0t) {
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(ne02 == ne12);
+
+ nth0 = 64;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ } break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(ne02 == ne12);
-
- nth0 = 32;
- nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
- } break;
default: GGML_ASSERT(false && "not implemented");
};
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
}
switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
default: GGML_ASSERT(false && "not implemented");
}