From 1cbf561466e957b25f0e8163c2386683f8674369 Mon Sep 17 00:00:00 2001 From: Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:10:55 -0400 Subject: metal : new q4_0 matrix-vector kernel (#2188) Prefetch data to improve GPU utilization. ~48% faster for 33B model. --- ggml-metal.metal | 103 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 56 insertions(+), 47 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index e62fe68..30d60fa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -365,6 +365,10 @@ kernel void kernel_rms_norm( } } +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 kernel void kernel_mul_mat_q4_0_f32( device const void * src0, device const float * src1, @@ -372,64 +376,69 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK4_0; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; + block_q4_0 qb_curr, qb_next; + float4 y_curr[8]; // src1 vector cache + float sumf[N_DST]={0.f}, all_sum; + thread float * yl=(thread float *)y_curr; + + // bootstrap + qb_curr = x[tiisg]; + // each thread in a SIMD group deals with 1 block. + for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + } - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - const int ix = tpitg.y/4; // 0 or 1 - const int iy = tpitg.y - 4*ix; // 0...3 - - const int first = 4 * iy; - - float sumf = 0; + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { + // calculate + float d = qb_curr.d; + float2 acc = {0.0f, 0.0f}; + for (int i = 0; i < 16; i++) { + acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + acc[1] += yl[i] + yl[i+16]; + } + sumf[row] += d * (acc[0] - 8.f*acc[1]); + qb_curr = qb_next; + } + } - const float d = (float)x[i].d; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + } - device const uint8_t * xl = x[i].qs + first; - device const float * yl = y + i * QK4_0 + first; + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; + // calculate + float d = qb_curr.d; float2 acc = {0.0f, 0.0f}; - - for (int j = 0; j < 4; ++j) { - - acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4); - acc[1] += yl[j] + yl[j+16]; - + for (int i = 0; i < 16; i++) { + acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + acc[1] += yl[i] + yl[i+16]; } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * (acc[0] - 8.f*acc[1]); + } + qb_curr = qb_next; - sumf += d * (acc[0] - 8.f*acc[1]); - } - - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } } -- cgit v1.2.3 From 27ad57a69b85bf12420a27e9945e580cc280be57 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 14 Jul 2023 12:46:21 +0300 Subject: Metal: faster Q4_0 and Q4_1 matrix x vector kernels (#2212) * 3-5% faster Q4_0 on Metal * 7-25% faster Q4_1 on Metal * Oops, forgot to delete the original Q4_1 kernel --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 174 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 103 insertions(+), 71 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 30d60fa..f094a1d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32( // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } + sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { // prefetch next x block @@ -405,39 +408,50 @@ kernel void kernel_mul_mat_q4_0_f32( // calculate float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; + float acc = sumy; for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); } - sumf[row] += d * (acc[0] - 8.f*acc[1]); + sumf[row] += d * acc; qb_curr = qb_next; } } - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; - for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + if (nb % N_SIMDWIDTH == 0) { + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * (acc[0] - 8.f*acc[1]); + } else { + + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - qb_curr = qb_next; + sumy *= (-8.f); + + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; + + // calculate + float d = qb_curr.d; + float acc = sumy; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * acc; + } + qb_curr = qb_next; - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } } } @@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - const int nb = ne00/QK4_1; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nb = ne00/QK4_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; + block_q4_1 qb_curr, qb_next; + float4 y_curr[8]; // src1 vector cache + float sumf[N_DST]={0.f}, all_sum; + thread float * yl=(thread float *)y_curr; - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; - - const int ix = tpitg.y/4; // 0 or 1 - const int iy = tpitg.y - 4*ix; // 0...3 - - const int first = 4 * iy; - - float sumf = 0; - - for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { - - const float d = (float)x[i].d; - const float m = (float)x[i].m; + // bootstrap + qb_curr = x[tiisg]; + // each thread in a SIMD group deals with 1 block. + for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - device const uint8_t * xl = x[i].qs + first; - device const float * yl = y + i * QK4_1 + first; + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + } - float2 acc = {0.0f, 0.0f}; + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - for (int j = 0; j < 4; ++j) { + // calculate + const float d = qb_curr.d; + const float m = qb_curr.m; + float acc = 0.f; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + sumf[row] += d * acc + m * sumy; + qb_curr = qb_next; + } + } - acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); - acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); + if (nb % N_SIMDWIDTH == 0) { + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } + } + } else { + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumf += acc[0] + acc[1]; - } + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - sum[ith] = sumf; + // calculate + const float d = qb_curr.d; + const float m = qb_curr.m; + float acc = 0.f; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * acc + m * sumy; + } + qb_curr = qb_next; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } + } } } -- cgit v1.2.3 From 6e7cca404748dd4b1a3affd0d1296e37f4ac0a6f Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Sat, 15 Jul 2023 06:34:16 -0400 Subject: llama : add custom RoPE (#2054) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement customizable RoPE The original RoPE has pre-defined parameters theta_i = 10000^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] Our customizable RoPE, ggml_rope_custom_inplace, uses theta_i = scale * base^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] with the default matches the original scale = 1.0 base = 10000 The new command line arguments --rope-freq-base --rope-freq-scale set the two new RoPE parameter. Recent researches show changing these two parameters extends the context limit with minimal loss. 1. Extending Context to 8K kaiokendev https://kaiokendev.github.io/til#extending-context-to-8k 2. Extending Context Window of Large Language Models via Positional Interpolation Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian https://arxiv.org/abs/2306.15595 3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation. https://www.reddit.com/user/bloc97 https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ For the bold, try adding the following command line parameters to your favorite model: -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5 * ggml-metal: fix custom rope * common: fix argument names in help * llama: increase MEM_REQ_EVAL for MODEL_3B It avoids crashing for quantized weights on CPU. Better ways to calculate the required buffer size would be better. * llama: make MEM_REQ_EVAL depend on n_ctx * server: use proper Content-Type in curl examples Without the header Content-Type: application/json, curl will POST with Content-Type: application/x-www-form-urlencoded Though our simple server doesn't care, the httplib.h used has a limit with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 With Content-Type: application/json, we can send large json data. * style : minor fixes, mostly indentations * ggml : fix asserts --------- Co-authored-by: Georgi Gerganov --- ggml-metal.metal | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index f094a1d..9f9a4fb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -656,17 +656,19 @@ kernel void kernel_rope( constant int & n_past, constant int & n_dims, constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint3 tpig[[thread_position_in_grid]]) { const int64_t i3 = tpig[2]; const int64_t i2 = tpig[1]; const int64_t i1 = tpig[0]; const bool is_neox = mode & 2; - const float theta_scale = pow(10000.0, -2.0f/n_dims); + const float theta_scale = pow(freq_base, -2.0f/n_dims); const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); - float theta = (float)p; + float theta = freq_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { -- cgit v1.2.3 From 417a85a0010519224cf154eb85d383ffeafeeead Mon Sep 17 00:00:00 2001 From: Shouzheng Liu Date: Thu, 20 Jul 2023 06:32:22 -0400 Subject: metal: minor q4 optimization and reduce code size (#2248) * metal: use uint16_t instead of uint8_t. Apple GPU doesn't like uint8_t. For every operation on uint8_t the gpu need to copy the uint8_t to an empty 16 bit register, then it can issue other instructions. For the matrix-vector multiplication kernel only, we observed a 340~350 GB/s memory read speed on M1 Max after this commit, which is very close to the reported hardware limit. * metal: update rms_norm kernel This commit double the speed of rms_norm operations by using 512 threads per threadgroup, combining with SIMD primitives to minimize the need for thread group barriers. * metal: use template to reduce size Revert modifications on block_q4_0 and block_q4_1. --- ggml-metal.metal | 236 +++++++++++++++++++++---------------------------------- 1 file changed, 91 insertions(+), 145 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 9f9a4fb..ee56336 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -331,26 +331,33 @@ kernel void kernel_rms_norm( threadgroup float * sum [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + float4 sumf=0; + float all_sum=0; // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (tiisg == 0) { + sum[sgitg] = all_sum; } - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 + for (int i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } } - - // broadcast if (tpitg == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} sum[0] /= ne00; } @@ -359,104 +366,102 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + device float4 * y = (device float4 *) (dst + tgpig*ne00); + device float * y_scalar = (device float *) y; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { y[i00] = x[i00] * scale; } + if (tpitg == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + } +} + +// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 1); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); +} + +// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float m = qb_curr->m; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 2); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -kernel void kernel_mul_mat_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +template +void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, + int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, + uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; - block_q4_0 qb_curr, qb_next; float4 y_curr[8]; // src1 vector cache float sumf[N_DST]={0.f}, all_sum; thread float * yl=(thread float *)y_curr; - // bootstrap - qb_curr = x[tiisg]; // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc; - qb_curr = qb_next; + sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); } } - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - + // from now loads two rows every time and 16 blocks per row + int ir = tiisg / (N_SIMDWIDTH / 2); + int ib = tiisg % (N_SIMDWIDTH / 2); + for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { + int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + for (int row = 0; row < N_DST; row+=2) { + if (nb_start + ib < nb) { + sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc; - } - qb_curr = qb_next; + } + } - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; } } } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mat_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -467,80 +472,21 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nb = ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; - device const float * y = (device const float *) src1 + r1*ne10; - block_q4_1 qb_curr, qb_next; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; - - // bootstrap - qb_curr = x[tiisg]; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc + m * sumy; - qb_curr = qb_next; - } - } - - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc + m * sumy; - } - qb_curr = qb_next; + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); +} - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( -- cgit v1.2.3 From 785829dfe8baf0213f2ff66963d28c62f92d7930 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 20 Jul 2023 15:18:43 +0300 Subject: Faster Q4_K on Metal (#2290) Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 236 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 143 insertions(+), 93 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index ee56336..a9d134d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1452,6 +1452,7 @@ kernel void kernel_mul_mat_q3_K_f32( } +#if QK_K == 256 kernel void kernel_mul_mat_q4_K_f32( device const void * src0, device const float * src1, @@ -1459,131 +1460,180 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - float sumf = 0; - -#if QK_K == 256 + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; + const int step = sizeof(block_q4_K) * nb / 2; - uchar2 sc1, sc2, sc3, sc4; + device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; + for (int ib = ix; ib < nb; ib += 4) { - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { + for (int row = 0; row < N_DST; row++) { - s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4); - s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + y4 += 4 * QK_K; } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } + } +} #else - uint16_t aux16[2]; - thread const uint8_t * scales = (thread const uint8_t *)aux16; +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int il = 4*tpitg.x; + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 - for (int i = tpitg.y; i < nb; i += tptg.y) { + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i * QK_K + il; + const int step = sizeof(block_q4_K) * nb / 2; - const float d = (float)x[i].d[0]; - const float m = (float)x[i].d[1]; + device const float * y4 = y + ix * QK_K + 8 * it; - device const uint16_t * a = (device const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; + uint16_t sc16[4]; - for (int l = 0; l < 4; ++l) { - sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) - + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; } - } -#endif - sum[ith] = sumf; + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} + y4 += 8 * QK_K; + } - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } + } } +#endif kernel void kernel_mul_mat_q5_K_f32( device const void * src0, -- cgit v1.2.3 From e782c9e735f93ab4767ffc37462c523b73a17ddc Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 20 Jul 2023 18:19:45 +0300 Subject: Faster Q5_K and Q6_K on Metal (#2294) * Faster Q6_K on Metal * Faster Q5_K on Metal * Another Q5_K speedup --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 228 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 125 insertions(+), 103 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index a9d134d..f71e8f3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1642,39 +1642,39 @@ kernel void kernel_mul_mat_q5_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf[2]={0.f}; - float sumf = 0; + const int step = sizeof(block_q5_K) * nb; #if QK_K == 256 +# + float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; + const int tid = tiisg/4; + const int ix = tiisg%4; + const int im = tid/4; + const int ir = tid%4; + const int n = 8; - const int l0 = n*(2*ir + in); + const int l0 = n*ir; const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; @@ -1683,78 +1683,114 @@ kernel void kernel_mul_mat_q5_K_f32( const uint8_t hm3 = hm1 << 4; const uint8_t hm4 = hm2 << 4; - uchar2 sc1, sc2, sc3, sc4; + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - for (int i = tpitg.x; i < nb; i += tptg.x) { + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const uint8_t * qh = (x + i)->qh + l0; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; + for (int i = ix; i < nb; i += 4) { - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + im; - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); + acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); + acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); + acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0)); - s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0)); - s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0)); - s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0)); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + q1 += step; + qh += step; + dh += step/2; + a += step/2; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + + y1 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 + float yl[8], yh[8]; - for (int i = tpitg.y; i < nb; i += tptg.y) { + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 - const float d = (float)x[i].d; + device const float * y = yy + ix*QK_K + il; + + for (int i = ix; i < nb; i += 8) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; device const uint8_t * q = x[i].qs + il; device const uint8_t * h = x[i].qh + in; device const int8_t * s = x[i].scales; - device const float * y = yy + i*QK_K + il; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; - sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16)) - + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16)) - + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16)) - + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16)); + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + } + + y += 8 * QK_K; } #endif - sum[ith] = sumf; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } } } @@ -1766,10 +1802,9 @@ kernel void kernel_mul_mat_q6_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -1781,19 +1816,18 @@ kernel void kernel_mul_mat_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const int row = 2 * r0 + sgitg; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; float sumf = 0; #if QK_K == 256 - // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! - const int iqs = 16 * tpitg.y; - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/16; // 0...7 + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; const int n = 4; const int l0 = n*il; const int is = 8*ip + l0/16; @@ -1802,9 +1836,10 @@ kernel void kernel_mul_mat_q6_K_f32( const int q_offset_l = 64*ip + l0; const int q_offset_h = 32*ip + l0; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int i = ix; i < nb; i += 2) { - device const uint8_t * ql = x[i].ql + q_offset_l; + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; @@ -1814,19 +1849,21 @@ kernel void kernel_mul_mat_q6_K_f32( float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); } + #else - const int il = 4*tpitg.x; // 0, 4, 8, 12 + const int ix = tiisg/4; + const int il = 4*(tiisg%4); - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int i = ix; i < nb; i += 8) { device const float * y = yy + i * QK_K + il; device const uint8_t * ql = x[i].ql + il; device const uint8_t * qh = x[i].qh + il; @@ -1846,23 +1883,8 @@ kernel void kernel_mul_mat_q6_K_f32( #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } - } -- cgit v1.2.3 From e68c96f7fee8fc22814a4a1209ffc97bbf35f7bd Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 21 Jul 2023 10:44:40 +0300 Subject: Faster Q2_K on Metal (#2297) * Faster Q2_K on Metal * Deleting unnoticed and dangereous trailing white space * Fixed bug in new metal Q2_K implementation --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 175 +++++++++++++++++++++++++++++++------------------------ 1 file changed, 100 insertions(+), 75 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index f71e8f3..97f5c10 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1209,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - float sumf = 0; + const int step = sizeof(block_q2_K) * nb; #if QK_K == 256 - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int ip = il/2; // 0 or 1 - const int shift1 = 4*(il%2);// 0 or 4 - const int shift2 = shift1+2;// 2 or 6 - const int n = 8; - const int is = 4*il + (n*ir)/16; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 - const int y_offset = 64*il + n*ir; - const int q_offset = 32*ip + n*ir; + device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int ib = ix; ib < nb; ib += 4) { - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * scales = x[i].scales + is; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } - uint8_t d1 = scales[0] & 0xF; - uint8_t d2 = scales[2] & 0xF; - uint8_t m1 = scales[0] >> 4; - uint8_t m2 = scales[2] >> 4; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - device const float * y = yy + i*QK_K + y_offset; + for (int row = 0; row < N_DST; row++) { - float2 s = {0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); - s[1] += y[l+32] * ((q[l] >> shift2) & 3); - smin += y[l+ 0] * m1 + y[l+32] * m2; + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; - - sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; - + y4 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0...1 - uint32_t aux[2]; - thread const uint8_t * d = (thread const uint8_t *)aux; - thread const uint8_t * m = (thread const uint8_t *)aux + 4; + device const float * y4 = y + ix * QK_K + 8 * it; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int ib = ix; ib < nb; ib += 16) { - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i*QK_K + il; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; - device const uint32_t * a = (device const uint32_t *)x[i].scales; - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = (a[0] >> 4) & 0x0f0f0f0f; + for (int row = 0; row < N_DST; row++) { - for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) - + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) - + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) - + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; } + + y4 += 16 * QK_K; } #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } } } -- cgit v1.2.3 From 4d76a5f49b9b5382dba5d13d92edb9159536c225 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 21 Jul 2023 17:05:30 +0300 Subject: Faster Q3_K implementation on Metal (#2307) * Faster Q3_K on Metal * Additional Q3_K speedup on Metal * Q3_K for QK_K = 64 * Better Q3_K for QK_K = 64 21.6 ms/t -> 21.1 ms/t --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 196 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 118 insertions(+), 78 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 97f5c10..5a9a6d8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -351,7 +351,7 @@ kernel void kernel_rms_norm( threadgroup_barrier(mem_flags::mem_threadgroup); // broadcast, simd group number is ntg / 32 - for (int i = ntg / 32 / 2; i > 0; i /= 2) { + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { sum[tpitg] += sum[tpitg + i]; } @@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32( } } +#if QK_K == 256 kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, @@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; -#if QK_K == 256 + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; + device const float * yy = (device const float *) src1 + r1*ne10; - const uint8_t m3 = 3; - const int8_t m4 = 4; + float yl[16]; const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int tid = tpitg.y; // expecting 16 + const int tid = tiisg/2; + const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 const int ir = tid%2; const int n = 8; const int l0 = n*ir; - const uint8_t m = 1 << (4*ip + il); + const uint16_t m1 = 1 << (4*ip + il); + const uint16_t m2 = m1 << 8; const int shift = 2*il; + const uint16_t qm1 = 0x0003 << shift; + const uint16_t qm2 = 0x0300 << shift; + const int32_t v1 = 4 << shift; + const int32_t v2 = 1024 << shift; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + 2*(il/2); @@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - //float sumf = 0; - float sumf1 = 0, sumf2 = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { + const int step = sizeof(block_q3_K) * nb / 2; - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * h = x[i].hmask + l0; - device const float * y = yy + i * QK_K + y_offset; + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint16_t * a = (device const uint16_t *)x[i].scales; - const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + float sumf1[2] = {0.f}, sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 2) { - float s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; + yl[l+8] = y1[l+16]; } - float d = d_all * s; - sumf1 += d * scales[0]; - sumf2 += d; - //sumf += d_all * s * (scales[0] - 32); - s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + + float s1 = 0, s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2]; + s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); + s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); + } + float d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[0]; + sumf2[row] += d; + + s1 = s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2+8]; + s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); + s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); + } + d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[1]; + sumf2[row] += d; + + q += step; + h += step; + a += step; + dh += step; + } - d = d_all * s; - sumf1 += d * scales[1]; - sumf2 += d; - //sumf += d_all * s * (scales[1] - 32); + + y1 += 2 * QK_K; } - //sum[ith] = sumf; - sum[ith] = sumf1 - 32.f*sumf2; + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } + } +} #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 +kernel void kernel_mul_mat_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne1, + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + const int row = 2 * r0 + sgitg; + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 const int in = il%8; // 0, 4, 0, 4 - float sumf = 0; + float2 sum = {0.f, 0.f}; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int i = ix; i < nb; i += 8) { const float d_all = (float)(x[i].d); - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].hmask + in; - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - - for (int l = 0; l < 4; ++l) { - const uint8_t hm = h[l] >> im; - sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) - + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) - + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) - + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; + + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> im; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); } } + const float sumf = sum[0] + sum[1] * 1.f/256.f; - sum[ith] = sumf; - -#endif - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } } +#endif #if QK_K == 256 kernel void kernel_mul_mat_q4_K_f32( @@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32( for (int i = ix; i < nb; i += 8) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { yl[l+0] = y[l+ 0]; yl[l+4] = y[l+16]; -- cgit v1.2.3 From 83a00ce69bef9124c0702424a012ea799128b77d Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Sun, 23 Jul 2023 19:00:37 +0800 Subject: metal : support bcast add & dup & cont op (#2323) --- ggml-metal.metal | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 5a9a6d8..987376d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -67,6 +67,17 @@ kernel void kernel_add( dst[tpig] = src0[tpig] + src1[tpig]; } +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % ne00]; +} + kernel void kernel_mul( device const float * src0, device const float * src1, -- cgit v1.2.3 From 9a08eaf3c4010962d0126e9e5bfbe9af64b2ac90 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Tue, 25 Jul 2023 13:48:29 +0300 Subject: Another speed gain for Q4_0 and Q4_1 on Metal (#2375) * Another speed gain for Q4_0 and Q4_1 on Metal * Have N_DST, etc., be template parameters --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.metal | 111 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 57 insertions(+), 54 deletions(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 987376d..696b33c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -387,87 +387,90 @@ kernel void kernel_rms_norm( } } -// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 1); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); + return d * (sumy * -8.f + acc[0] + acc[1]); } -// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) -float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - float4 acc = 0.f; - device uint16_t * qs = ((device uint16_t *)qb_curr + 2); - for (int i = 0; i < 16; i+=2) { - acc[0] += yl[i] * (qs[i / 2] & 0x000F); - acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); - acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; + return d * (acc[0] + acc[1]) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -template +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// giard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + const int first_row = (r0 * nsg + sgitg) * nr; + device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb; device const float * y = (device const float *) src1 + r1*ne10; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; + float yl[16]; // src1 vector cache + float sumf[nr]={0.f}; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } + const int ix = tiisg/2; + const int il = 8*(tiisg%2); - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); - } - } + device const float * yb = y + ix * QK4_0 + il; - // from now loads two rows every time and 16 blocks per row - int ir = tiisg / (N_SIMDWIDTH / 2); - int ib = tiisg % (N_SIMDWIDTH / 2); - for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { - int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; } - for (int row = 0; row < N_DST; row+=2) { - if (nb_start + ib < nb) { - sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); - } + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); } + + yb += QK4_0 * 16; } - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + first_row + row] = tot; } } } @@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_q4_1_f32( @@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( -- cgit v1.2.3 From 1873ff586bd8499a18f763632711bf15d253585e Mon Sep 17 00:00:00 2001 From: Matteo Boschini <12133566+mbosc@users.noreply.github.com> Date: Tue, 1 Aug 2023 09:43:12 +0200 Subject: metal : add gqa8 kernel to allow llama-2-70B on metal (#2459) * Added gqa8 kernel to allow llama-2-70B on metal * Update ggml-metal.m Co-authored-by: Cebtenzzre * Extend kernel_mul_mat_f16_f32 to handle gqa broadcast * Added ne03==ne13 assertion --------- Co-authored-by: Cebtenzzre --- ggml-metal.metal | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'ggml-metal.metal') diff --git a/ggml-metal.metal b/ggml-metal.metal index 696b33c..8d26b5e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32( device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant int64_t & ne02, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); sum[tpitg.x] = 0.0f; @@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32( } } + kernel void kernel_alibi_f32( device const float * src0, device float * dst, -- cgit v1.2.3