aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r--ggml-metal.metal174
1 files changed, 103 insertions, 71 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 30d60fa..f094a1d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
+ float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
+ sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
@@ -405,39 +408,50 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate
float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
+ float acc = sumy;
for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
+ acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
+ sumf[row] += d * acc;
qb_curr = qb_next;
}
}
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
- }
-
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float2 acc = {0.0f, 0.0f};
- for (int i = 0; i < 16; i++) {
- acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- acc[1] += yl[i] + yl[i+16];
+ if (nb % N_SIMDWIDTH == 0) {
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ }
}
- if (tiisg < nb % N_SIMDWIDTH) {
- sumf[row] += d * (acc[0] - 8.f*acc[1]);
+ } else {
+
+ float sumy = 0;
+ for (int i = 0; i < QK4_0 / 4; i++) {
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
- qb_curr = qb_next;
+ sumy *= (-8.f);
+
+ for (int row = 0; row < N_DST; row++) {
+ // prefetch next x block
+ qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
+
+ // calculate
+ float d = qb_curr.d;
+ float acc = sumy;
+ for (int i = 0; i < 16; i++) {
+ acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+ }
+ if (tiisg < nb % N_SIMDWIDTH) {
+ sumf[row] += d * acc;
+ }
+ qb_curr = qb_next;
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ }
}
}
}
@@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
- threadgroup float * sum [[threadgroup(0)]],
+ constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
- uint2 tpitg[[thread_position_in_threadgroup]],
- uint2 tptg[[threads_per_threadgroup]]) {
- const int nb = ne00/QK4_1;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
-
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int nb = ne00/QK4_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const float * y = (device const float *) src1 + r1*ne10;
+ block_q4_1 qb_curr, qb_next;
+ float4 y_curr[8]; // src1 vector cache
+ float sumf[N_DST]={0.f}, all_sum;
+ thread float * yl=(thread float *)y_curr;
- const uint nth = tptg.x*tptg.y;
- const uint ith = tptg.y*tpitg.x + tpitg.y;
-
- const int ix = tpitg.y/4; // 0 or 1
- const int iy = tpitg.y - 4*ix; // 0...3
-
- const int first = 4 * iy;
-
- float sumf = 0;
-
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
- const float d = (float)x[i].d;
- const float m = (float)x[i].m;
+ // bootstrap
+ qb_curr = x[tiisg];
+ // each thread in a SIMD group deals with 1 block.
+ for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
- device const uint8_t * xl = x[i].qs + first;
- device const float * yl = y + i * QK4_1 + first;
+ float sumy = 0;
+ for (int i = 0; i < QK4_0 / 4; i++) {
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+ }
- float2 acc = {0.0f, 0.0f};
+ for (int row = 0; row < N_DST; row++) {
+ // prefetch next x block
+ qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
- for (int j = 0; j < 4; ++j) {
+ // calculate
+ const float d = qb_curr.d;
+ const float m = qb_curr.m;
+ float acc = 0.f;
+ for (int i = 0; i < 16; i++) {
+ acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+ }
+ sumf[row] += d * acc + m * sumy;
+ qb_curr = qb_next;
+ }
+ }
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
+ if (nb % N_SIMDWIDTH == 0) {
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ }
+ }
+ } else {
+ float sumy = 0;
+ for (int i = 0; i < QK4_0 / 4; i++) {
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+ sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
- sumf += acc[0] + acc[1];
- }
+ for (int row = 0; row < N_DST; row++) {
+ // prefetch next x block
+ qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
- sum[ith] = sumf;
+ // calculate
+ const float d = qb_curr.d;
+ const float m = qb_curr.m;
+ float acc = 0.f;
+ for (int i = 0; i < 16; i++) {
+ acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+ }
+ if (tiisg < nb % N_SIMDWIDTH) {
+ sumf[row] += d * acc + m * sumy;
+ }
+ qb_curr = qb_next;
- //
- // Accumulate the sum from all threads in the threadgroup
- //
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[r1*ne0 + r0] = sum[0];
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ }
+ }
}
}