about summary refs log tree commit diff
path: root/ggml-metal.m
diff options
context:
space:
mode:
authorJiahao Li <liplus17@163.com>2023-07-23 19:00:37 +0800
committerGitHub <noreply@github.com>2023-07-23 14:00:37 +0300
commit83a00ce69bef9124c0702424a012ea799128b77d (patch)
tree835c093a8fe5ce1c50305b78e44442ea30c5e1ed /ggml-metal.m
parentd2a43664f93ba30a84e42713bb69f936cbdacf2a (diff)
metal : support bcast add & dup & cont op (#2323)
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--  ggml-metal.m  12
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/ggml-metal.m b/ggml-metal.m
index 2810fa2..78a3b65 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -42,6 +42,7 @@ struct ggml_metal_context {
id<MTLComputePipelineState> pipeline_##name
GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
GGML_METAL_DECL_KERNEL(mul);
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(add_row);
GGML_METAL_ADD_KERNEL(mul);
GGML_METAL_ADD_KERNEL(mul_row);
GGML_METAL_ADD_KERNEL(scale);
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
- [encoder setComputePipelineState:ctx->pipeline_add];
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(dst);
@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+ case GGML_OP_DUP:
case GGML_OP_CPY:
+ case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];