author     Matteo Boschini <12133566+mbosc@users.noreply.github.com>  2023-08-01 09:43:12 +0200
committer  GitHub <noreply@github.com>                                2023-08-01 10:43:12 +0300
commit     1873ff586bd8499a18f763632711bf15d253585e (patch)
tree       f5c52d81b59d9044b2cd2b3b584e05be268ec278
parent     49e7cb5bb1f75c91dd5db7d2d88cbc11bd9ee0c5 (diff)
metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
* Added gqa8 kernel to allow llama-2-70B on metal

* Update ggml-metal.m

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>

* Extend kernel_mul_mat_f16_f32 to handle gqa broadcast

* Added ne03==ne13 assertion

---------

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
-rw-r--r--  ggml-metal.m      33
-rw-r--r--  ggml-metal.metal   5
2 files changed, 21 insertions(+), 17 deletions(-)
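
For context: with grouped-query attention, the key/value side of the matrix product (src0, ne02 slices) has fewer slices than the activations it is multiplied with (src1, ne12 slices); in llama-2-70B, 64 query heads share 8 KV heads, hence "gqa8". The patch broadcasts src0 by mapping each src1 slice back to its KV group with an integer division, the i02/(ne12/ne02) (im/(ne12/ne02) in the kernel) expression that appears in the hunks below. A minimal C sketch of that mapping, with illustrative values; the function and variable names are not part of the patch:

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Sketch only: map a src1 slice index (query head) to the src0 slice
     * (KV head) it broadcasts from, mirroring i02/(ne12/ne02) in the patch. */
    static int64_t src0_slice(int64_t i12, int64_t ne02, int64_t ne12) {
        assert(ne12 % ne02 == 0);      /* broadcast needs an integer ratio */
        return i12 / (ne12 / ne02);    /* consecutive query heads share a KV head */
    }

    int main(void) {
        const int64_t ne02 = 8, ne12 = 64;   /* illustrative llama-2-70B-like values */
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            printf("query head %2lld -> kv head %lld\n",
                   (long long) i12, (long long) src0_slice(i12, ne02, ne12));
        }
        return 0;
    }

When ne12 == ne02 the ratio is 1 and the mapping reduces to the old one-to-one indexing, so non-GQA models behave as before.
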
diff --git a/ggml-metal.m b/ggml-metal.m
index 74a6bff..3f098d3 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ GGML_ASSERT(ne03 == ne13);
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
- // we need to do ne02 multiplications
+ // we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02*nb02;
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F16:
{
- GGML_ASSERT(ne02 == ne12);
-
nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
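
A note on the re-numbering above (this reflects how Metal binds arguments in general, not something the patch itself spells out): each setBytes:length:atIndex: call must match the buffer slot the kernel declares for that parameter, so inserting ne02 at index 5 and ne12 at index 11 shifts every later argument up by one on both sides. A hypothetical C enum sketching the resulting layout; the names are illustrative, only the numbers mirror the calls above:

    /* Hypothetical argument-slot layout for kernel_mul_mat_f16_f32 after the
     * patch; src0/src1 at slots 0/1 are set just above the shown hunk. */
    enum mul_mat_f16_f32_arg_index {
        ARG_SRC0 = 0,  ARG_SRC1 = 1,  ARG_DST  = 2,
        ARG_NE00 = 3,  ARG_NE01 = 4,  ARG_NE02 = 5,   /* ne02: newly inserted */
        ARG_NB00 = 6,  ARG_NB01 = 7,  ARG_NB02 = 8,
        ARG_NE10 = 9,  ARG_NE11 = 10, ARG_NE12 = 11,  /* ne12: newly inserted */
        ARG_NB10 = 12, ARG_NB11 = 13, ARG_NB12 = 14,
        ARG_NE0  = 15, ARG_NE1  = 16,
    };

Keeping the host-side indices and the kernel's parameter order in lockstep is why both files change together in this commit.
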
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 696b33c..8d26b5e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
+ constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
+ constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
- device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f;
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
}
}
+
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
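
On the kernel side, the only functional change is which src0 slice a threadgroup at z-position im reads. A small C sketch of the two byte-offset computations, using ggml's nbXY byte strides; the names, the main function, and its example strides are illustrative, not taken from the patch:

    #include <inttypes.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Sketch (not the actual kernel): byte offsets added to src0/src1 for the
     * rows read at threadgroup position (r0, r1, im). */
    static uint64_t x_offset(int64_t r0, int64_t im, int64_t ne02, int64_t ne12,
                             uint64_t nb01, uint64_t nb02) {
        return r0*nb01 + (im/(ne12/ne02))*nb02;   /* src0 slice is broadcast */
    }

    static uint64_t y_offset(int64_t r1, int64_t im, uint64_t nb11, uint64_t nb12) {
        return r1*nb11 + im*nb12;                 /* src1 has one slice per im */
    }

    int main(void) {
        /* illustrative values: 128-wide rows, 32 rows per slice, 8 KV / 64 query heads */
        const int64_t  ne02 = 8, ne12 = 64;
        const uint64_t nb01 = 128*2, nb02 = nb01*32, nb11 = 128*4, nb12 = nb11*32;
        printf("im=9: x at byte %" PRIu64 ", y at byte %" PRIu64 "\n",
               x_offset(0, 9, ne02, ne12, nb01, nb02), y_offset(0, 9, nb11, nb12));
        return 0;
    }

Only the x offset is divided by the GQA factor; y (and, in the host loop, the destination) still advances one slice per im, which is what makes this a broadcast of src0 rather than a reshape.
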