aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m7
-rw-r--r--ggml-metal.metal236
2 files changed, 147 insertions, 96 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index d80a380..5e2a211 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
- nth0 = 4;
- nth1 = 16;
+ nth0 = 2;
+ nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
@@ -739,7 +739,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q2_K ||
diff --git a/ggml-metal.metal b/ggml-metal.metal
index ee56336..a9d134d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1452,6 +1452,7 @@ kernel void kernel_mul_mat_q3_K_f32(
}
+#if QK_K == 256
kernel void kernel_mul_mat_q4_K_f32(
device const void * src0,
device const float * src1,
@@ -1459,131 +1460,180 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
-
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- float sumf = 0;
-
-#if QK_K == 256
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;
- const int l0 = n*(2*ir + in);
- const int q_offset = 32*im + l0;
- const int y_offset = 64*im + l0;
+ const int step = sizeof(block_q4_K) * nb / 2;
- uchar2 sc1, sc2, sc3, sc4;
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int ib = ix; ib < nb; ib += 4) {
- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;
- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < N_DST; row++) {
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ sc += step;
+ dh += step;
}
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+ y4 += 4 * QK_K;
}
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
+ }
+}
#else
- uint16_t aux16[2];
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
+kernel void kernel_mul_mat_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int il = 4*tpitg.x;
+ const int ix = tiisg/4; // 0...7
+ const int it = tiisg%4; // 0...3
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[8];
+ float yh[8];
+ float sumf[N_DST]={0.f}, all_sum;
- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i * QK_K + il;
+ const int step = sizeof(block_q4_K) * nb / 2;
- const float d = (float)x[i].d[0];
- const float m = (float)x[i].d[1];
+ device const float * y4 = y + ix * QK_K + 8 * it;
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- aux16[0] = a[0] & 0x0f0f;
- aux16[1] = (a[0] >> 4) & 0x0f0f;
+ uint16_t sc16[4];
- for (int l = 0; l < 4; ++l) {
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
+ for (int ib = ix; ib < nb; ib += 8) {
+
+ float2 sumy = {0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
}
- }
-#endif
- sum[ith] = sumf;
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = x[ib].d;
- //
- // Accumulate the sum from all threads in the threadgroup
- // This version is slightly faster than the commented out one below,
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
- }
+ for (int row = 0; row < N_DST; row++) {
+
+ sc16[0] = sc[0] & 0x000f;
+ sc16[1] = sc[0] & 0x0f00;
+ sc16[2] = sc[0] & 0x00f0;
+ sc16[3] = sc[0] & 0xf000;
+
+ float2 acc1 = {0.f, 0.f};
+ float2 acc2 = {0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+ qs += step;
+ sc += step;
+ dh += step;
+ }
- //// accumulate the sum from all threads in the threadgroup
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = nth/2; i > 0; i /= 2) {
- // if (ith < i) {
- // sum[ith] += sum[ith + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
+ y4 += 8 * QK_K;
+ }
- //if (ith == 0) {
- // dst[r1*ne0 + r0] = sum[0];
- //}
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
+ }
}
+#endif
kernel void kernel_mul_mat_q5_K_f32(
device const void * src0,