path: root/ggml-metal.m
author    Aaron Miller <apage43@ninjawhale.com>    2023-06-17 07:37:49 -0700
committer GitHub <noreply@github.com>              2023-06-17 17:37:49 +0300
commit    0711a5f6dce7f04c2a791b14bc47f7d4cb545408 (patch)
tree      9a14de4dbc3eb6fdc14f8838f81441ebb50eff08 /ggml-metal.m
parent    fc45a81bc642b9ef33d9004f2b363d558438a6c9 (diff)
metal : add norm, cpy f16->f16, alibi kernels (#1823)
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--    ggml-metal.m    73
1 file changed, 73 insertions(+), 0 deletions(-)
diff --git a/ggml-metal.m b/ggml-metal.m
index 0e9b56a..8148512 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -57,6 +57,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q5_k);
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f16);
#undef GGML_METAL_DECL_KERNEL
};
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,8 +175,10 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f16);
#undef GGML_METAL_ADD_KERNEL
}
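
For context on the two hunks above: adding a kernel is mechanical because of the GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL pair. A rough Objective-C sketch of that pattern follows (illustrative, not the verbatim macro bodies from ggml-metal.m): the first macro declares a function and pipeline-state member per kernel, the second looks up the Metal function named "kernel_<name>" in the compiled library and builds a compute pipeline state for it.

// Illustrative sketch of the declare/register pattern (not the exact macros
// in ggml-metal.m). Each kernel "name" maps to a Metal function "kernel_<name>".
#define GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction>             function_##name; \
    id<MTLComputePipelineState> pipeline_##name

#define GGML_METAL_ADD_KERNEL(name) \
    ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_" #name]; \
    ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name \
                                                                      error:nil]

So, for example, GGML_METAL_ADD_KERNEL(norm) registers the kernel_norm function that this commit adds on the Metal shader side (not shown here, since this diffstat is limited to ggml-metal.m).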
@@ -735,6 +741,65 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_NORM:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float eps = 1e-5f;
+
+ const int nth = 256;
+
+ [encoder setComputePipelineState:ctx->pipeline_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
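
What this new GGML_OP_NORM path computes, for reference: a per-row layer normalization without scale or offset (hence no parameters beyond eps), with the kernel reducing each row in threadgroup memory, which is why nth*sizeof(float) of threadgroup storage is bound at index 0. A plain-C reference of the per-row math, assuming a contiguous f32 row — a sketch of the semantics, not the Metal kernel itself:

#include <math.h>
#include <stdint.h>

// Reference semantics of GGML_OP_NORM for one f32 row of length ne00
// (sketch only; the Metal kernel performs the same reduction in parallel).
static void norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
    float mean = 0.0f;
    for (int64_t i = 0; i < ne00; i++) mean += x[i];
    mean /= ne00;

    float var = 0.0f;
    for (int64_t i = 0; i < ne00; i++) {
        const float d = x[i] - mean;
        y[i]  = d;      // store the centered value
        var  += d*d;
    }
    var /= ne00;

    const float scale = 1.0f/sqrtf(var + eps); // eps is the 1e-5f constant above
    for (int64_t i = 0; i < ne00; i++) y[i] *= scale;
}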
+ case GGML_OP_ALIBI:
+ {
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_head = ((int32_t *) src1->data)[1];
+ const float max_bias = ((float *) src1->data)[2];
+ if (__builtin_popcount(n_head) != 1) {
+ GGML_ASSERT(false && "only power-of-two n_head implemented");
+ }
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+ [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ const int nth = 32;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
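
For the ALiBi case, the host side only precomputes m0 = 2^(-max_bias / n_heads_log2_floor) and asserts a power-of-two head count; the kernel is expected to derive a per-head slope from m0 and add a bias that grows linearly with the column index. A plain-C sketch of that semantics (mirroring ggml's CPU ALiBi for the power-of-two case, assuming contiguous f32 data; not the Metal kernel verbatim):

#include <math.h>
#include <stdint.h>

// ALiBi bias, power-of-two head count: head k uses slope m0^(k+1), and
// column j gets j*slope added (bias grows with distance from position 0).
static void alibi_ref(const float * src, float * dst,
                      int64_t ne00,          // columns per row
                      int64_t ne01,          // rows per head
                      int64_t ne02,          // number of heads
                      float   m0) {
    for (int64_t k = 0; k < ne02; k++) {
        const float slope = powf(m0, (float)(k + 1));
        for (int64_t r = 0; r < ne01; r++) {
            const float * s = src + (k*ne01 + r)*ne00;
            float       * d = dst + (k*ne01 + r)*ne00;
            for (int64_t j = 0; j < ne00; j++) {
                d[j] = s[j] + (float) j*slope;
            }
        }
    }
}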
case GGML_OP_ROPE:
{
if (encoder == nil) {
@@ -788,6 +853,14 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
};
} break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+ case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
default: GGML_ASSERT(false && "not implemented");
}
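
Finally, the cpy_f16_f16 pipeline fills in the remaining same-type copy for GGML_OP_CPY: source and destination are both half precision, so each element is a plain 16-bit move, but the copy still has to honor the byte strides of a possibly non-contiguous tensor (the kernel's ne*/nb* arguments are set outside this hunk). A plain-C sketch of such a strided copy, treating halves as raw 16-bit words, with hypothetical parameter names following ggml's usual ne/nb convention:

#include <stdint.h>
#include <string.h>

// Strided f16 -> f16 copy: no conversion, each element is sizeof(uint16_t) bytes.
// nb0x are source byte strides, nbx destination byte strides (ggml convention).
static void cpy_f16_f16_ref(const char * src, char * dst,
                            int64_t  ne0,  int64_t  ne1,  int64_t  ne2,  int64_t  ne3,
                            uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03,
                            uint64_t nb0,  uint64_t nb1,  uint64_t nb2,  uint64_t nb3) {
    for (int64_t i3 = 0; i3 < ne3; i3++)
    for (int64_t i2 = 0; i2 < ne2; i2++)
    for (int64_t i1 = 0; i1 < ne1; i1++)
    for (int64_t i0 = 0; i0 < ne0; i0++) {
        memcpy(dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0,
               src + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00,
               sizeof(uint16_t));
    }
}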