aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m4
-rw-r--r--ggml-metal.metal236
2 files changed, 93 insertions, 147 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index ee205bc..d80a380 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -792,7 +792,7 @@ void ggml_metal_graph_compute(
const float eps = 1e-6f;
- const int nth = 256;
+ const int nth = 512;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -800,7 +800,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 9f9a4fb..ee56336 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -331,26 +331,33 @@ kernel void kernel_rms_norm(
threadgroup float * sum [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+ device const float * x_scalar = (device const float *) x;
+ float4 sumf=0;
+ float all_sum=0;
// parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00] * x[i00];
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (tiisg == 0) {
+ sum[sgitg] = all_sum;
}
- // reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ // broadcast, simd group number is ntg / 32
+ for (int i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
}
-
- // broadcast
if (tpitg == 0) {
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
sum[0] /= ne00;
}
@@ -359,104 +366,102 @@ kernel void kernel_rms_norm(
const float mean = sum[0];
const float scale = 1.0f/sqrt(mean + eps);
- device float * y = dst + tgpig*ne00;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ device float * y_scalar = (device float *) y;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
+ if (tpitg == 0) {
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+ }
+}
+
+// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+ float d = qb_curr->d;
+ float4 acc = 0.f;
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
+ for (int i = 0; i < 16; i+=2) {
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+}
+
+// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+ float4 acc = 0.f;
+ device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
+ for (int i = 0; i < 16; i+=2) {
+ acc[0] += yl[i] * (qs[i / 2] & 0x000F);
+ acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+ acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ }
+ return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
}
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-kernel void kernel_mul_mat_q4_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+template<typename block_q_type>
+void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+ int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+ uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const float * y = (device const float *) src1 + r1*ne10;
- block_q4_0 qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
- // bootstrap
- qb_curr = x[tiisg];
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
- sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float acc = sumy;
- for (int i = 0; i < 16; i++) {
- acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- }
- sumf[row] += d * acc;
- qb_curr = qb_next;
+ sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
}
}
- if (nb % N_SIMDWIDTH == 0) {
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
- }
- }
- } else {
-
+ // from now loads two rows every time and 16 blocks per row
+ int ir = tiisg / (N_SIMDWIDTH / 2);
+ int ib = tiisg % (N_SIMDWIDTH / 2);
+ for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
+ int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+ y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
- sumy *= (-8.f);
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- float d = qb_curr.d;
- float acc = sumy;
- for (int i = 0; i < 16; i++) {
- acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+ for (int row = 0; row < N_DST; row+=2) {
+ if (nb_start + ib < nb) {
+ sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
}
- if (tiisg < nb % N_SIMDWIDTH) {
- sumf[row] += d * acc;
- }
- qb_curr = qb_next;
+ }
+ }
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
- }
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
}
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mat_q4_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -467,80 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int nb = ne00/QK4_0;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
- device const float * y = (device const float *) src1 + r1*ne10;
- block_q4_1 qb_curr, qb_next;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
-
- // bootstrap
- qb_curr = x[tiisg];
- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
- float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
- }
-
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- const float d = qb_curr.d;
- const float m = qb_curr.m;
- float acc = 0.f;
- for (int i = 0; i < 16; i++) {
- acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- }
- sumf[row] += d * acc + m * sumy;
- qb_curr = qb_next;
- }
- }
-
- if (nb % N_SIMDWIDTH == 0) {
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
- }
- }
- } else {
-
- float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
- }
-
- for (int row = 0; row < N_DST; row++) {
- // prefetch next x block
- qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
- // calculate
- const float d = qb_curr.d;
- const float m = qb_curr.m;
- float acc = 0.f;
- for (int i = 0; i < 16; i++) {
- acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
- }
- if (tiisg < nb % N_SIMDWIDTH) {
- sumf[row] += d * acc + m * sumy;
- }
- qb_curr = qb_next;
+ mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+}
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
- }
- }
- }
+kernel void kernel_mul_mat_q4_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne10,
+ constant int64_t & ne0,
+ constant int64_t & ne01[[buffer(4)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_f16_f32(