diff options
author | Jiahao Li <liplus17@163.com> | 2023-07-23 19:00:37 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 14:00:37 +0300 |
commit | 83a00ce69bef9124c0702424a012ea799128b77d (patch) | |
tree | 835c093a8fe5ce1c50305b78e44442ea30c5e1ed /ggml-metal.m | |
parent | d2a43664f93ba30a84e42713bb69f936cbdacf2a (diff) |
metal : support bcast add & dup & cont op (#2323)
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 12 |
1 file changed, 11 insertions, 1 deletion
diff --git a/ggml-metal.m b/ggml-metal.m
index 2810fa2..78a3b65 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -42,6 +42,7 @@ struct ggml_metal_context {
     id<MTLComputePipelineState> pipeline_##name

     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
                             encoder = [command_buffer computeCommandEncoder];
                         }

-                        [encoder setComputePipelineState:ctx->pipeline_add];
+                        if (ggml_nelements(src1) == ne10) {
+                            // src1 is a row
+                            [encoder setComputePipelineState:ctx->pipeline_add_row];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

                         const int64_t n = ggml_nelements(dst);
@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
                                 encoder = [command_buffer computeCommandEncoder];