diff options
author | Jiahao Li <liplus17@163.com> | 2023-07-23 19:00:37 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 14:00:37 +0300 |
commit | 83a00ce69bef9124c0702424a012ea799128b77d (patch) | |
tree | 835c093a8fe5ce1c50305b78e44442ea30c5e1ed | |
parent | d2a43664f93ba30a84e42713bb69f936cbdacf2a (diff) |
metal : support bcast add & dup & cont op (#2323)
-rw-r--r-- | ggml-metal.m | 12 | ||||
-rw-r--r-- | ggml-metal.metal | 11 |
2 files changed, 22 insertions, 1 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 2810fa2..78a3b65 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -42,6 +42,7 @@ struct ggml_metal_context { id<MTLComputePipelineState> pipeline_##name GGML_METAL_DECL_KERNEL(add); + GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(scale); @@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); GGML_METAL_ADD_KERNEL(add); + GGML_METAL_ADD_KERNEL(add_row); GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(scale); @@ -464,10 +466,16 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - [encoder setComputePipelineState:ctx->pipeline_add]; + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_add_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_add]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -919,7 +927,9 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_DUP: case GGML_OP_CPY: + case GGML_OP_CONT: { if (encoder == nil) { encoder = [command_buffer computeCommandEncoder]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 5a9a6d8..987376d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -67,6 +67,17 @@ kernel void kernel_add( dst[tpig] = src0[tpig] + src1[tpig]; } +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % ne00]; +} + kernel void kernel_mul( device const float * src0, device const float * src1, |