diff options
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r-- | ggml-metal.metal | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal index 696b33c..8d26b5e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32( device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant int64_t & ne02, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); sum[tpitg.x] = 0.0f; @@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32( } } + kernel void kernel_alibi_f32( device const float * src0, device float * dst, |