aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m96
1 files changed, 53 insertions, 43 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index bf3f68f..1fd6e85 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
} break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
}
}