diff options
author | AT <manyoso@users.noreply.github.com> | 2023-06-09 04:00:51 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-09 11:00:51 +0300 |
commit | 92f44ff7f778ef1b94028b2ba6d39943b5ca0ada (patch) | |
tree | 851ddf5dbe9cbd01a44e8d516aac4c11c351e095 /ggml-metal.m | |
parent | 245fc3c37da5ac5963f9f11a9f4f2ac08d96afc6 (diff) |
metal : add GELU implementation (#1770)
Co-authored-by: Adam Treat <adam@nomic.ai>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 54cbaf8..5c9ecd7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -45,6 +45,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); + GGML_METAL_DECL_KERNEL(gelu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); @@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); + GGML_METAL_ADD_KERNEL(gelu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); @@ -420,6 +422,20 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_GELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { if (encoder == nil) { |