diff options
-rw-r--r-- | ggml-metal.m | 4 | ||||
-rw-r--r-- | ggml-metal.metal | 236 |
2 files changed, 93 insertions, 147 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index ee205bc..d80a380 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -792,7 +792,7 @@ void ggml_metal_graph_compute( const float eps = 1e-6f; - const int nth = 256; + const int nth = 512; [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -800,7 +800,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 9f9a4fb..ee56336 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -331,26 +331,33 @@ kernel void kernel_rms_norm( threadgroup float * sum [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + float4 sumf=0; + float all_sum=0; // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (tiisg == 0) { + sum[sgitg] = all_sum; } - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 + for (int i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } } - - // broadcast if (tpitg == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} sum[0] /= ne00; } @@ -359,104 +366,102 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + device float4 * y = (device float4 *) (dst + tgpig*ne00); + device float * y_scalar = (device float *) y; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { y[i00] = x[i00] * scale; } + if (tpitg == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + } +} + +// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 1); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); +} + +// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float m = qb_curr->m; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 2); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -kernel void kernel_mul_mat_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +template<typename block_q_type> +void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, + int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, + uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; - block_q4_0 qb_curr, qb_next; float4 y_curr[8]; // src1 vector cache float sumf[N_DST]={0.f}, all_sum; thread float * yl=(thread float *)y_curr; - // bootstrap - qb_curr = x[tiisg]; // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc; - qb_curr = qb_next; + sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); } } - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - + // from now loads two rows every time and 16 blocks per row + int ir = tiisg / (N_SIMDWIDTH / 2); + int ib = tiisg % (N_SIMDWIDTH / 2); + for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { + int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + for (int row = 0; row < N_DST; row+=2) { + if (nb_start + ib < nb) { + sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc; - } - qb_curr = qb_next; + } + } - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; } } } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mat_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -467,80 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nb = ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; - device const float * y = (device const float *) src1 + r1*ne10; - block_q4_1 qb_curr, qb_next; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; - - // bootstrap - qb_curr = x[tiisg]; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc + m * sumy; - qb_curr = qb_next; - } - } - - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc + m * sumy; - } - qb_curr = qb_next; + mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); +} - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( |