aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m19
-rw-r--r--ggml-metal.metal228
2 files changed, 136 insertions, 111 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 5e2a211..44d0468 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
- nth0 = 4;
- nth1 = 16;
+ nth0 = 2;
+ nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
- nth0 = 4;
- nth1 = 16;
+ nth0 = 2;
+ nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
@@ -743,11 +743,14 @@ void ggml_metal_graph_compute(
src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_Q3_K ||
- src0t == GGML_TYPE_Q4_K ||
- src0t == GGML_TYPE_Q5_K ||
- src0t == GGML_TYPE_Q6_K) {
+ src0t == GGML_TYPE_Q3_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index a9d134d..f71e8f3 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1642,39 +1642,39 @@ kernel void kernel_mul_mat_q5_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf[2]={0.f};
- float sumf = 0;
+ const int step = sizeof(block_q5_K) * nb;
#if QK_K == 256
+#
+ float yl[16], yh[16];
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
-
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
- const int in = il%2;
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int im = tid/4;
+ const int ir = tid%4;
+ const int n = 8;
- const int l0 = n*(2*ir + in);
+ const int l0 = n*ir;
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
@@ -1683,78 +1683,114 @@ kernel void kernel_mul_mat_q5_K_f32(
const uint8_t hm3 = hm1 << 4;
const uint8_t hm4 = hm2 << 4;
- uchar2 sc1, sc2, sc3, sc4;
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ device const float * y1 = yy + ix*QK_K + y_offset;
- device const uint8_t * q1 = (x + i)->qs + q_offset;
- device const uint8_t * q2 = q1 + 64;
- device const uint8_t * qh = (x + i)->qh + l0;
- device const float * y1 = yy + i*QK_K + y_offset;
- device const float * y2 = y1 + 128;
+ for (int i = ix; i < nb; i += 4) {
- const float dall = (float)((x + i)->d);
- const float dmin = (float)((x + i)->dmin);
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }
- float4 s = {0.f, 0.f, 0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
+ for (int row = 0; row < 2; ++row) {
+
+ device const uint8_t * q2 = q1 + 64;
+
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+ q1 += step;
+ qh += step;
+ dh += step/2;
+ a += step/2;
}
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+ y1 += 4 * QK_K;
}
#else
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
- const int im = il/8; // 0, 0, 1, 1
- const int in = il%8; // 0, 4, 0, 4
+ float yl[8], yh[8];
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
+ const int ix = tiisg%8;
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
- const float d = (float)x[i].d;
+ device const float * y = yy + ix*QK_K + il;
+
+ for (int i = ix; i < nb; i += 8) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ yl[l+0] = y[l+ 0];
+ yl[l+4] = y[l+16];
+ yh[l+0] = y[l+32];
+ yh[l+4] = y[l+48];
+ }
+
+ device const half * dh = &x[i].d;
device const uint8_t * q = x[i].qs + il;
device const uint8_t * h = x[i].qh + in;
device const int8_t * s = x[i].scales;
- device const float * y = yy + i*QK_K + il;
- for (int l = 0; l < 4; ++l) {
- const uint8_t hl = h[l] >> im;
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
+ for (int row = 0; row < 2; ++row) {
+
+ const float d = dh[0];
+
+ float2 acc = {0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hl = h[l] >> im;
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+ }
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+ q += step;
+ h += step;
+ s += step;
+ dh += step/2;
+
}
+
+ y += 8 * QK_K;
}
#endif
- sum[ith] = sumf;
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < 2; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = tot;
+ }
}
}
@@ -1766,10 +1802,9 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@@ -1781,19 +1816,18 @@ kernel void kernel_mul_mat_q6_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const int row = 2 * r0 + sgitg;
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
float sumf = 0;
#if QK_K == 256
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
- const int iqs = 16 * tpitg.y;
- const int ip = iqs / 128; // 0 or 1
- const int il = (iqs - 128*ip)/16; // 0...7
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
const int n = 4;
const int l0 = n*il;
const int is = 8*ip + l0/16;
@@ -1802,9 +1836,10 @@ kernel void kernel_mul_mat_q6_K_f32(
const int q_offset_l = 64*ip + l0;
const int q_offset_h = 32*ip + l0;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ for (int i = ix; i < nb; i += 2) {
- device const uint8_t * ql = x[i].ql + q_offset_l;
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
device const int8_t * sc = x[i].scales + is;
@@ -1814,19 +1849,21 @@ kernel void kernel_mul_mat_q6_K_f32(
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) {
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
}
+
#else
- const int il = 4*tpitg.x; // 0, 4, 8, 12
+ const int ix = tiisg/4;
+ const int il = 4*(tiisg%4);
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int i = ix; i < nb; i += 8) {
device const float * y = yy + i * QK_K + il;
device const uint8_t * ql = x[i].ql + il;
device const uint8_t * qh = x[i].qh + il;
@@ -1846,23 +1883,8 @@ kernel void kernel_mul_mat_q6_K_f32(
#endif
- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + row] = tot;
}
-
}