diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | Makefile | 5 | ||||
-rw-r--r-- | ggml-cuda.cu | 599 |
3 files changed, 385 insertions, 221 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index ea9f80b..dbbc0b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) @@ -201,6 +202,7 @@ if (LLAMA_CUBLAS) add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) + add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) @@ -171,6 +171,11 @@ ifdef LLAMA_CUDA_DMMV_Y else NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 endif # LLAMA_CUDA_DMMV_Y +ifdef LLAMA_CUDA_KQUANTS_ITER + NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) +else + NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 +endif ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bd89d0a..7edd1a9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -167,6 +167,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define GGML_CUDA_DMMV_Y 1 #endif +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -326,37 +332,6 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { } -static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q2_K * x = (const block_q2_K *) vx; - - // if n is 0, we want to do the lower 128, else the upper 128, - // covering y[l+0], y[l+32], y[l+64], y[l+96] and - // y[l+16], y[l+48], y[l+80], y[l+112] - int n = iqs/128; // 0 or 1 - int r = iqs - 128*n; // 0...120 in steps of 8 - int l = r/8; // 0...15 in steps of 1 - - const float * y = yy + 128*n + l; - const uint8_t * q = x[ib].qs + 32*n + l; - const uint8_t * s = x[ib].scales + 8*n; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4)) - + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4)) - + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4)) - + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4)) - + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4)) - + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4)) - + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4)) - + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4)); - - result = sum; - -} - static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { int r = threadIdx.x/4; @@ -388,51 +363,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } -static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q3_K * x = (const block_q3_K *) vx; - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - uint32_t aux[3]; - uint32_t utmp[4]; - - // if n is 0, we want to do the lower 128, else the upper 128, - // covering y[l+0], y[l+32], y[l+64], y[l+96] and - // y[l+16], y[l+48], y[l+80], y[l+112] - int n = iqs/128; // 0 or 1 - int r = iqs - 128*n; // 0...120 in steps of 8 - int l = r/8; // 0...15 in steps of 1 - - const float * y = yy + 128*n + l; - const uint8_t * q = x[ib].qs + 32*n + l; - const uint8_t * hm = x[ib].hmask + l; - const int8_t * s = (const int8_t *)utmp + 8*n; - - memcpy(aux, x[ib].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - const float dall = x[ib].d; - - const uint8_t m = 1 << (4*n); - - float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) - + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) - + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) - + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) - + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) - + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) - + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) - + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); - - result = sum * dall; - -} - static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -479,38 +409,6 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { } } -static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q4_K * x = (const block_q4_K *) vx; - - // iqs is in 0...248 in steps of 8 => - const int j = iqs / 64; // j is in 0...3 - const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 - const int is = 2*j; // is is in 0...6 in steps of 2 - - const float * y = yy + 64*j + ir; - const uint8_t * q = x[ib].qs + 32*j + ir; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1); - sum += y[k + 32] * (d2 * (q[k] >> 4) - m2); - } - result = sum; - -} - static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const block_q5_K * x = (const block_q5_K *) vx; @@ -544,43 +442,6 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; } -static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { - - const block_q5_K * x = (const block_q5_K *) vx; - - // iqs is in 0...248 in steps of 8 => - const int j = iqs / 64; // j is in 0...3 - const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4 - const int is = 2*j; // is is in 0...6 in steps of 2 - - const float * y = yy + 64*j + ir; - const uint8_t * ql = x[ib].qs + 32*j + ir; - const uint8_t * qh = x[ib].qh + ir; - - const float dall = x[ib].d; - const float dmin = x[ib].dmin; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[ib].scales, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[ib].scales, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - uint8_t hm = 1 << is; - float sum = 0; - for (int k = 0; k < 4; ++k) { - sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1); - } - hm <<= 1; - for (int k = 0; k < 4; ++k) { - sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2); - } - result = sum; - -} - static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { const block_q6_K * x = (const block_q6_K *) vx; @@ -606,31 +467,376 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } -static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { +static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const block_q6_K * x = (const block_q6_K *) vx; + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/8; // 0...15 - const int is = 8*ip; + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; - const float * y = yy + 128*ip + il; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; - const float d = x[ib].d; + const block_q2_K * x = (const block_q2_K *)vx + ib0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 + + const int step = 16/K_QUANTS_PER_ITERATION; + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; // 0, 1 + + const int n = 2; // iterations in the inner loop + const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - 8*im; // 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...28 in steps of 4 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp - const uint8_t * ql = x[ib].ql + 64*ip + il; - const uint8_t * qh = x[ib].qh + 32*ip + il; - const int8_t * sc = x[ib].scales + is; + for (int i = ix; i < num_blocks_per_row; i += 2) { - result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) - + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32) - + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32) - + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32) - + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32) - + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32) - + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32) - + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32); + const uint8_t * q1 = x[i].qs + q_offset; + const uint8_t * q2 = q1 + 64; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); + s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + //const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * ql2 = ql1 + 64; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = x[i].d; + const float dmin = x[i].dmin; + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0)); + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } } static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ @@ -712,46 +918,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } -template <int n_thread, dot_kernel_k_t dot_kernel> -static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = QK_K; - const int vals_per_iter = iter_stride / n_thread; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - float tmp = 0; // partial sum for thread in warp - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int ib = ib0 + col/QK_K; // x block index - const int iqs = col%QK_K; // x quant index - const int iybs = col - col%QK_K; // y block start index - - float v; - dot_kernel(vx, ib, iqs, y + iybs, v); - tmp += v; - } - - // sum up partial sums and write back result - __syncthreads(); -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { const half * x = (half *) vx; @@ -1094,43 +1260,34 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); + dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; + const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); + dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { |