about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- ggml-metal.m 9
-rw-r--r-- ggml-metal.metal 175
2 files changed, 104 insertions, 80 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 44d0468..135bda9 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
- nth0 = 4;
- nth1 = 16;
+ nth0 = 2;
+ nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
@@ -740,7 +740,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
- src0t == GGML_TYPE_Q4_K) {
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q5_K) {
@@ -749,8 +749,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_Q3_K) {
+ else if (src0t == GGML_TYPE_Q3_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index f71e8f3..97f5c10 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1209,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
- const int nth = tptg.x*tptg.y;
- const int ith = tptg.y*tpitg.x + tpitg.y;
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+ device const float * y = (device const float *) src1 + r1*ne10;
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
- float sumf = 0;
+ const int step = sizeof(block_q2_K) * nb;
#if QK_K == 256
- const int tid = tpitg.y; // 0...16
- const int il = tid/4; // 0...3
- const int ir = tid%4; // 0...3
- const int ip = il/2; // 0 or 1
- const int shift1 = 4*(il%2);// 0 or 4
- const int shift2 = shift1+2;// 2 or 6
- const int n = 8;
- const int is = 4*il + (n*ir)/16;
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int im = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
- const int y_offset = 64*il + n*ir;
- const int q_offset = 32*ip + n*ir;
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
- for (int i = tpitg.x; i < nb; i += tptg.x) {
+ for (int ib = ix; ib < nb; ib += 4) {
- device const uint8_t * q = x[i].qs + q_offset;
- device const uint8_t * scales = x[i].scales + is;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }
- uint8_t d1 = scales[0] & 0xF;
- uint8_t d2 = scales[2] & 0xF;
- uint8_t m1 = scales[0] >> 4;
- uint8_t m2 = scales[2] >> 4;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+ device const half * dh = &x[ib].d;
- device const float * y = yy + i*QK_K + y_offset;
+ for (int row = 0; row < N_DST; row++) {
- float2 s = {0.f, 0.f};
- float smin = 0;
- for (int l = 0; l < n; ++l) {
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
- smin += y[l+ 0] * m1 + y[l+32] * m2;
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
}
- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
-
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+ y4 += 4 * QK_K;
}
#else
- const int il = 4 * tpitg.x;
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0...1
- uint32_t aux[2];
- thread const uint8_t * d = (thread const uint8_t *)aux;
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+ device const float * y4 = y + ix * QK_K + 8 * it;
- for (int i = tpitg.y; i < nb; i += tptg.y) {
+ for (int ib = ix; ib < nb; ib += 16) {
- device const uint8_t * q = x[i].qs + il;
- device const float * y = yy + i*QK_K + il;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+ }
- const float dall = (float)x[i].d;
- const float dmin = (float)x[i].dmin;
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = &x[ib].d;
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
- aux[0] = a[0] & 0x0f0f0f0f;
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+ for (int row = 0; row < N_DST; row++) {
- for (int l = 0; l < 4; ++l) {
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
}
+
+ y4 += 16 * QK_K;
}
#endif
- sum[ith] = sumf;
-
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + first_row + row] = all_sum;
+ }
}
}