aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
authorMatteo Boschini <12133566+mbosc@users.noreply.github.com>2023-08-01 09:43:12 +0200
committerGitHub <noreply@github.com>2023-08-01 10:43:12 +0300
commit1873ff586bd8499a18f763632711bf15d253585e (patch)
treef5c52d81b59d9044b2cd2b3b584e05be268ec278 /ggml-metal.m
parent49e7cb5bb1f75c91dd5db7d2d88cbc11bd9ee0c5 (diff)
metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
* Added gqa8 kernel to allow llama-2-70B on metal * Update ggml-metal.m Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com> * Extend kernel_mul_mat_f16_f32 to handle gqa broadcast * Added ne03==ne13 assertion --------- Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m33
1 files changed, 17 insertions, 16 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 74a6bff..3f098d3 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ GGML_ASSERT(ne03 == ne13);
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
- // we need to do ne02 multiplications
+ // we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02*nb02;
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F16:
{
- GGML_ASSERT(ne02 == ne12);
-
nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {