aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal5
1 files changed, 4 insertions, 1 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 696b33c..8d26b5e 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
+ constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
+ constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
- device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f;
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
}
}
+
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,