From 11f3ca06b8c66b0427aab0a472479da22553b472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 29 Jul 2023 23:04:44 +0200 Subject: CUDA: Quantized matrix matrix multiplication (#2160) * mmq implementation for non k-quants * q6_K * q2_K * q3_k * q4_K * vdr * q5_K * faster q8_1 loading * loop unrolling * add __restrict__ * q2_K sc_high * GGML_CUDA_MMQ_Y * Updated Makefile * Update Makefile * DMMV_F16 -> F16 * Updated README, CMakeLists * Fix CMakeLists.txt * Fix CMakeLists.txt * Fix multi GPU out-of-bounds --- CMakeLists.txt | 8 +- Makefile | 15 +- README.md | 6 +- ggml-cuda.cu | 1586 +++++++++++++++++++++++++++++++++++++++++++++----------- 4 files changed, 1293 insertions(+), 322 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c43e65e..6e1abea 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,9 @@ endif() option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CUBLAS "llama: use CUDA" OFF) +option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) +set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels") option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") @@ -251,6 +253,10 @@ if (LLAMA_CUBLAS) set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) + if (LLAMA_CUDA_CUBLAS) + add_compile_definitions(GGML_CUDA_CUBLAS) + endif() + add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y}) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) endif() diff --git a/Makefile b/Makefile index 2035c52..3d1fff8 100644 --- a/Makefile +++ b/Makefile @@ -194,7 +194,7 @@ ifdef LLAMA_CUBLAS CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o - NVCCFLAGS = --forward-unknown-to-host-compiler + NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math ifdef LLAMA_CUDA_NVCC NVCC = $(LLAMA_CUDA_NVCC) else @@ -220,14 +220,25 @@ else ifdef LLAMA_CUDA_DMMV_Y else NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 endif # LLAMA_CUDA_MMV_Y +ifdef LLAMA_CUDA_F16 + NVCCFLAGS += -DGGML_CUDA_F16 +endif # LLAMA_CUDA_F16 ifdef LLAMA_CUDA_DMMV_F16 - NVCCFLAGS += -DGGML_CUDA_DMMV_F16 + NVCCFLAGS += -DGGML_CUDA_F16 endif # LLAMA_CUDA_DMMV_F16 ifdef LLAMA_CUDA_KQUANTS_ITER NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) else NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 endif +ifdef LLAMA_CUDA_MMQ_Y + NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y) +else + NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64 +endif # LLAMA_CUDA_MMQ_Y +ifdef LLAMA_CUDA_CUBLAS + NVCCFLAGS += -DGGML_CUDA_CUBLAS +endif # LLAMA_CUDA_CUBLAS ifdef LLAMA_CUDA_CCBIN NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) endif diff --git a/README.md b/README.md index 6a3268d..42fc42b 100644 --- a/README.md +++ b/README.md @@ -402,10 +402,12 @@ Building the program with BLAS support may lead to some performance improvements | Option | Legal values | Default | Description | |-------------------------|------------------------|---------|-------------| + | LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). | + | LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. | | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | - | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | + | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | - #### CLBlast diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 511f48c..0a43fb5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -52,13 +52,41 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } while (0) #endif // CUDART_VERSION >= 11 -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 typedef half dfloat; // dequantize float typedef half2 dfloat2; #else typedef float dfloat; // dequantize float typedef float2 dfloat2; -#endif //GGML_CUDA_DMMV_F16 +#endif //GGML_CUDA_F16 + +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream); @@ -87,8 +115,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 #define QR4_1 2 #define QI4_1 (QK4_1 / (4 * QR4_1)) typedef struct { - half d; // delta - half m; // min + half2 dm; // dm.x = delta, dm.y = min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); @@ -107,8 +134,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QR5_1 2 #define QI5_1 (QK5_1 / (4 * QR5_1)) typedef struct { - half d; // delta - half m; // min + half2 dm; // dm.x = delta, dm.y = min uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; @@ -127,13 +153,19 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo #define QR8_1 1 #define QI8_1 (QK8_1 / (4 * QR8_1)) typedef struct { - half d; // delta - half s; // unquantized sum + half2 ds; // ds.x = delta, ds.y = sum int8_t qs[QK8_0]; // quants } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs); +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); +typedef void (*load_tiles_cuda_t)( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row); +typedef float (*vec_dot_q_mul_mat_cuda_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); //================================= k-quants @@ -150,8 +182,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins + half2 dm; // super-block scale for quantized scales/mins } block_q2_K; static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); @@ -180,8 +211,7 @@ typedef struct { static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); #else typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins + half2 dm; // super-block scale for quantized scales/mins uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; @@ -200,11 +230,10 @@ typedef struct { static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); #else typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); #endif @@ -233,6 +262,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#ifndef GGML_CUDA_MMQ_Y +#define GGML_CUDA_MMQ_Y 64 +#endif // GGML_CUDA_MMQ_Y + // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X #define GGML_CUDA_DMMV_X 32 @@ -367,33 +400,33 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hsub2(v, {8.0f, 8.0f}); v = __hmul2(v, {d, d}); #else v.x = (v.x - 8.0f) * d; v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const dfloat d = x[ib].d; - const dfloat m = x[ib].m; + const dfloat d = x[ib].dm.x; + const dfloat m = x[ib].dm.y; const int vui = x[ib].qs[iqs]; v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); v = __hadd2(v, {m, m}); #else v.x = (v.x * d) + m; v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -410,20 +443,20 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hsub2(v, {16.0f, 16.0f}); v = __hmul2(v, {d, d}); #else v.x = (v.x - 16.0f) * d; v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const dfloat d = x[ib].d; - const dfloat m = x[ib].m; + const dfloat d = x[ib].dm.x; + const dfloat m = x[ib].dm.y; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -434,13 +467,13 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); v = __hadd2(v, {m, m}); #else v.x = (v.x * d) + m; v.y = (v.y * d) + m; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ @@ -451,12 +484,12 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in v.x = x[ib].qs[iqs + 0]; v.y = x[ib].qs[iqs + 1]; -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 v = __hmul2(v, {d, d}); #else v.x *= d; v.y *= d; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } //================================== k-quants @@ -475,8 +508,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const uint8_t q = x[i].qs[32*n + l]; float * y = yy + i*QK_K + 128*n; - float dall = x[i].d; - float dmin = x[i].dmin; + float dall = x[i].dm.x; + float dmin = x[i].dm.y; y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); @@ -486,8 +519,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const int il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); float * y = yy + i*QK_K + 16*is + il; - float dall = x[i].d; - float dmin = x[i].dmin; + float dall = x[i].dm.x; + float dmin = x[i].dm.y; y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); #endif @@ -573,8 +606,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float float * y = yy + i*QK_K + 64*il + n*ir; - const float dall = x[i].d; - const float dmin = x[i].dmin; + const float dall = x[i].dm.x; + const float dmin = x[i].dm.y; const uint8_t * q = x[i].qs + 32*il + n*ir; @@ -612,8 +645,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float float * y = yy + i*QK_K + 64*il + 2*ir; - const float dall = x[i].d; - const float dmin = x[i].dmin; + const float dall = x[i].dm.x; + const float dmin = x[i].dm.y; const uint8_t * ql = x[i].qs + 32*il + 2*ir; const uint8_t * qh = x[i].qh + 2*ir; @@ -725,8 +758,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; - const float dall = x[i].d; - const float dmin = x[i].dmin; + const float dall = x[i].dm.x; + const float dmin = x[i].dm.y; const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); aux[0] = a[0] & 0x0f0f0f0f; @@ -768,9 +801,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, uaux[0] = s[0] & 0x0f0f0f0f; uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; - const half2 * dh = (const half2 *)&x[i].d; - - const float2 dall = __half22float2(dh[0]); + const float2 dall = __half22float2(x[i].dm); float sum1 = 0, sum2 = 0; for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { @@ -948,8 +979,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; - const float dall = x[i].d; - const float dmin = x[i].dmin; + const float dall = x[i].dm.x; + const float dmin = x[i].dm.y; const uint16_t * a = (const uint16_t *)x[i].scales; aux[0] = a[im+0] & kmask1; @@ -1081,8 +1112,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; - const float dall = x[i].d; - const float dmin = x[i].dmin; + const float dall = x[i].dm.x; + const float dmin = x[i].dm.y; const uint16_t * a = (const uint16_t *)x[i].scales; aux[0] = a[im+0] & kmask1; @@ -1270,19 +1301,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { + if (ix >= kx_padded) { return; } + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + + const int i_padded = iy*kx_padded + ix; + block_q8_1 * y = (block_q8_1 *) vy; - const int ib = i / QK8_1; // block index - const int iqs = i % QK8_1; // quant index + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index - const float xi = i < ndata ? x[i] : 0.0f; + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -1301,8 +1336,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest return; } - y[ib].d = d; - y[ib].s = sum; + y[ib].ds.x = d; + y[ib].ds.y = sum; } template @@ -1326,45 +1361,79 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __ y[iybs + iqs + y_offset] = v.y; } -static __device__ __forceinline__ float vec_dot_q4_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called - int vi; - memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]); +#define VDR_q4_0_q8_1 1 - const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d); +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics // subtract 8 from each quantized value - const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); - const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808); + const int vi0 = (vi >> 0) & 0x0F0F0F0F; + const int vi1 = (vi >> 4) & 0x0F0F0F0F; // SIMD dot product of quantized values int sumi = __dp4a(vi0, ui0, 0); sumi = __dp4a(vi1, ui1, sumi); - return sumi*d; + return __half2float(d4) * (sumi * __half2float(ds8.x) - (8/QI4_0) * __half2float(ds8.y)); #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q4_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + const int vi = get_int_from_uint8(bq4_0->qs, iqs); + const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs); + const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_0); - const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); + return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0)]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_d; +} - const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d); - const float m = bq4_1->m; - const float s = bq8_1->s; +static __device__ __forceinline__ void load_tiles_q4_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx = ((block_q4_0 *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI4_0) + kbx].x = bx->d; +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + return vec_dot_q4_0_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], + x_dm[i * (WARP_SIZE/QI4_0) + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]); +} + +#define VDR_q4_1_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int & vi, const int & ui0, const int & ui1, const half2 & dm4, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const int vi0 = (vi >> 0) & 0x0F0F0F0F; const int vi1 = (vi >> 4) & 0x0F0F0F0F; @@ -1372,184 +1441,421 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( int sumi = __dp4a(vi0, ui0, 0); sumi = __dp4a(vi1, ui1, sumi); - return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block +#ifdef GGML_CUDA_F16 + const half2 tmp = __hmul2(dm4, ds8); + const float d4d8 = __half2float(tmp.x); + const float m4s8 = __half2float(tmp.y); +#else + const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x); + const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y); +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/QR4_1 to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / QR4_1); #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q5_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + const int vi = get_int_from_uint8_aligned(bq4_1->qs, iqs); + const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs); + const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_1); + + return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1)]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + +static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx = ((block_q4_1 *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI4_1) + kbx] = bx->dm; +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + return vec_dot_q4_1_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], + x_dm[i * (WARP_SIZE/QI4_1) + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]); +} + +#define VDR_q5_0_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int & qs, const int & qh, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) { + #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (qh << 4) & 0x00000010; // 0 -> 4 + vi0 |= (qh << 11) & 0x00001000; // 1 -> 12 + vi0 |= (qh << 18) & 0x00100000; // 2 -> 20 + vi0 |= (qh << 25) & 0x10000000; // 3 -> 28 + int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (qh << 2) & 0x00100000; // 18 -> 20 + vi1 |= (qh << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + + return __half2float(d5) * (sumi*__half2float(ds8.x) - (16/QI5_0) * __half2float(ds8.y)); +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; - int qs; - memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2); - const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]); - - const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d); - - int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits - vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 - vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 - vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 - vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 - vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values + const int qs = get_int_from_uint8(bq5_0->qs, iqs); + const int qh = get_int_from_uint8(bq5_0->qh, 0) >> (4 * iqs); + const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs); + const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0); + + return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)]; + + *x_ql = tile_x_ql; + *x_qh = tile_x_qh; + *x_dm = tile_x_d; +} + +static __device__ __forceinline__ void load_tiles_q5_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx = ((block_q5_0 *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx); + x_qh[i * (WARP_SIZE / QI5_0) + kbx] = get_int_from_uint8(bx->qh, 0); + x_dm[i * (WARP_SIZE / QI5_0) + kbx].x = bx->d; +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0; + + return vec_dot_q5_0_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs], + y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]); +} + +#define VDR_q5_1_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int & qs, const int & qh, const int & ui0, const int & ui1, const half2 & dm5, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits + vi0 |= (qh << 4) & 0x00000010; // 0 -> 4 + vi0 |= (qh << 11) & 0x00001000; // 1 -> 12 + vi0 |= (qh << 18) & 0x00100000; // 2 -> 20 + vi0 |= (qh << 25) & 0x10000000; // 3 -> 28 int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values - int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits - vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 - vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 - vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 - vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 - vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits + vi1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (qh << 2) & 0x00100000; // 18 -> 20 + vi1 |= (qh << 9) & 0x10000000; // 19 -> 28 sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values - return sumi*d; +#ifdef GGML_CUDA_F16 + const half2 tmp = __hmul2(dm5, ds8); + const float d5d8 = __half2float(tmp.x); + const float m5s8 = __half2float(tmp.y); +#else + const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x); + const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y); +#endif // GGML_CUDA_F16 + + return sumi*d5d8 + m5s8/QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block + #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; - const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); - const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); - const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); - const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); - const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]); - - const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d); - const float m = bq5_1->m; - const float s = bq8_1->s; - - int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits - vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 - vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 - vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 - vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 - int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + const int qs = get_int_from_uint8_aligned(bq5_1->qs, iqs); + const int qh = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * iqs); + const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs); + const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_1); - int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits - vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 - vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 - vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 - vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 - sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + return vec_dot_q5_1_q8_1_impl(qs, qh, ui0, ui1, bq5_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)]; + + *x_ql = tile_x_ql; + *x_qh = tile_x_qh; + *x_dm = tile_x_dm; +} + +static __device__ __forceinline__ void load_tiles_q5_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx = ((block_q5_1 *) vx) + i*blocks_per_row + kbx; - return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx); + x_qh[i * (WARP_SIZE / QI5_1) + kbx] = get_int_from_uint8(bx->qh, 0); + x_dm[i * (WARP_SIZE / QI5_1) + kbx] = bx->dm; +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0; + + return vec_dot_q5_1_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2*WARP_SIZE) + kyqs], + y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]); +} + +#define VDR_q8_0_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int & vi, const int & ui, const half & d8_0, const half2 & ds8_1) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + // SIMD dot product of quantized values + const int sumi = __dp4a(vi, ui, 0); + + return sumi * __half2float(d8_0) * __half2float(ds8_1.x); #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; - int vi; - memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); - const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int vi = get_int_from_int8(bq8_0->qs, iqs); + const int ui = get_int_from_int8_aligned(bq8_1->qs, iqs); - const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d); + return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds); +} - // SIMD dot product of quantized values - int sumi = __dp4a(vi, ui, 0); +static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0)]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_d; +} + +static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + + const block_q8_0 * bx = ((block_q8_0 *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI8_0) + kbx].x = bx->d; +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + return vec_dot_q8_0_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k], + x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +#define VDR_q2_K_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + int sc_high = sc >> 4; + sc_high |= sc_high << 8; + sc_high |= sc_high << 16; + sumf_m += d8[i] * __dp4a(sc_high, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dmf = __half22float2(dm); - return sumi*d; + return dmf.x*sumf_d - dmf.y*sumf_m; #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q2_K * bq2_K = (const block_q2_K *) vbq; const int bq8_offset = QR2_K * (iqs / QI8_1); const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const uint8_t * scales = bq2_K->scales + scale_offset; - const float d = bq2_K->d; - const float dmin = bq2_K->dmin; + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; - const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]); + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds.x; + } - for (int i = 0; i < QR2_K; ++i) { - const int sc = bq2_K->scales[scale_offset + 2*i]; + return vec_dot_q2_K_q8_1_impl(v, u, scales, bq2_K->dm, d8); +} - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const float d8i = bq8i->d; +static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - const int vi = (v >> (2*i)) & 0x03030303; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)]; - sumf_d += d8i * (__dp4a(vi, ui, 0) * (sc & 0xF)); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values - } + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} - return d*sumf_d - dmin*sumf_m; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +static __device__ __forceinline__ void load_tiles_q2_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx = ((block_q2_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI2_K) + kbx] = bx->dm; + x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->scales, kqsx / 4); } -static __device__ __forceinline__ float vec_dot_q3_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q3_K * bq3_K = (const block_q3_K *) vbq; + __builtin_assume(i < GGML_CUDA_MMQ_Y); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); - const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; - float sumf = 0.0f; + const int bq8_offset = QR2_K * (kqsx / QI8_1); + const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2); - const float d = bq3_K->d; + const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16 + scale_offset; + + int u[QR2_K]; + float d8[QR2_K]; + + for (int l = 0; l < QR2_K; ++ l) { + const int y_qs_index = j * (QR2_K*WARP_SIZE) + kbx * (QR2_K*QI2_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1; + u[l] = y_qs[y_qs_index]; + d8[l] = y_ds[y_qs_index / QI8_1].x; + } - int vl; - memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int)); + return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + kbx], d8); +} + +#define VDR_q3_K_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d, const float * __restrict__ d8) { - int vh; - memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int)); - vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted - vh >>= bq8_offset; +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; for (int i = 0; i < QR3_K; ++i) { const int isc = scale_offset + 2*i; const int isc_low = isc % (QK_K/32); const int sc_shift_low = 4 * (isc / (QK_K/32)); - const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF; + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; const int isc_high = isc % (QK_K/64); const int sc_shift_high = 2 * (isc / (QK_K/64)); - const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; const int sc = (sc_low | sc_high) - 32; - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); - const float d8i = bq8i->d; - const int vil = (vl >> (2*i)) & 0x03030303; const int vih = ((vh >> i) << 2) & 0x04040404; const int vi = __vsubss4(vil, vih); - sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product } return d*sumf; @@ -1558,31 +1864,136 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q4_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_K * bq4_K = (const block_q4_K *) vbq; + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds.x; + } + + return vec_dot_q3_K_q8_1_impl(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE / 2)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q3_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx = ((block_q3_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI3_K) + kbx].x = bx->d; + x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->hmask, kqsx / 2); + x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8(bx->scales, kqsx / 4); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const int bq8_offset = QR3_K * (kqsx / (QI3_K/2)); + const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2); + + const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16; + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~x_qh[i * (WARP_SIZE/2) + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + + for (int l = 0; l < QR3_K; ++ l) { + const int y_qs_index = j * (QR3_K*WARP_SIZE) + kbx * (QR3_K*QI3_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1; + u[l] = y_qs[y_qs_index]; + d8[l] = y_ds[y_qs_index / QI8_1].x; + } + + return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI3_K) + kbx].x, d8); +} + +#define VDR_q4_K_q8_1 2 + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + return __half2float(dm4.x)*sumf_d - __half2float(dm4.y)*sumf_m; + +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + #ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 const int bq8_offset = QR4_K * (iqs / (QI8_1/2)); - const float d = bq4_K->d; - const float dmin = bq4_K->dmin; - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4)); - const int v1 = q4[0]; - const int v2 = q4[4]; + v[0] = q4[0]; + v[1] = q4[4]; const uint16_t * scales = (const uint16_t *)bq4_K->scales; uint16_t aux[2]; @@ -1598,60 +2009,161 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const uint8_t * m = sc + 2; for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const float d8i = bq8i->d; - const int * q8 = (const int *)bq8i->qs + (iqs%4); - const int ui1 = q8[0]; - const int ui2 = q8[4]; - - const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F; - const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F; - - const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + d8[i] = bq8i->ds.x; - sumf_d += d8i * (dot1 * sc[i]); - sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + const int * q8 = (const int *)bq8i->qs + (iqs%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; } - return d*sumf_d - dmin*sumf_m; + return vec_dot_q4_K_q8_1_impl(v, u, sc, m, bq4_K->dm, d8); #else +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + uint16_t aux16[2]; const uint8_t * s = (const uint8_t *)aux16; - const uint16_t * a = (const uint16_t *)bq4_K->scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->d[0]; + const float dmin = bq4_K->d[1]; + + const float d8_1 = bq8_1[0].ds.x; + const float d8_2 = bq8_1[1].ds.x; + + const int ui1 = *((const int *)bq8_1[0].qs + iqs); + const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); + const int ui3 = *((const int *)bq8_1[1].qs + iqs); + const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + + const int * q4 = (const int *)bq4_K->qs + iqs; + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif +} + +static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI4_K; + const int kqsx = k % QI4_K; + + const block_q4_K * bx = ((block_q4_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm; + x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32)); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + __builtin_assume(i < GGML_CUDA_MMQ_Y); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * (kqsx / (QI8_1/2)); + + v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0]; + v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4]; + + const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)]; + uint16_t aux[2]; + const int l = bq8_offset/2; + if (l < 2) { + aux[0] = scales[l+0] & 0x3f3f; + aux[1] = scales[l+2] & 0x3f3f; + } else { + aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2); + aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int l = 0; l < QR4_K; ++l) { + const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2); + u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)]; + u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)]; + d8[l] = y_ds[kqsy / QI8_1].x; + } + + return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8); +} + +#define VDR_q5_K_q8_1 2 - const float dall = bq4_K->d[0]; - const float dmin = bq4_K->d[1]; +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { - const float d8_1 = bq8_1[0].d; - const float d8_2 = bq8_1[1].d; +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; - const int ui1 = *((const int *)bq8_1[0].qs + iqs); - const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); - const int ui3 = *((const int *)bq8_1[1].qs + iqs); - const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4); + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; - const int * q4 = (const int *)bq4_K->qs + iqs; - const int v1 = q4[0]; - const int v2 = q4[4]; + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; - const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; - sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); - sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u - return dall * sumf_d - dmin * sumf_m; + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); -#endif + } + + return __half2float(dm5.x)*sumf_d - __half2float(dm5.y)*sumf_m; #else return 0.0f; // only to satisfy the compiler @@ -1659,28 +2171,25 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +#ifndef GGML_QKK_64 const block_q5_K * bq5_K = (const block_q5_K *) vbq; -#ifndef GGML_QKK_64 + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; const int bq8_offset = QR5_K * (iqs / (QI8_1/2)); const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4)); const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4)); - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - const float d = bq5_K->d; - const float dmin = bq5_K->dmin; + vl[0] = ql[0]; + vl[1] = ql[4]; - const int vl1 = ql[0]; - const int vl2 = ql[4]; - - const int vh1 = qh[0] >> bq8_offset; - const int vh2 = qh[4] >> bq8_offset; + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; const uint16_t * scales = (const uint16_t *)bq5_K->scales; uint16_t aux[2]; @@ -1696,40 +2205,27 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const uint8_t * m = sc + 2; for (int i = 0; i < QR5_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - const float d8i = bq8i->d; - const int * q8 = (const int *)bq8i->qs + (iqs%4); - const int ui1 = q8[0]; - const int ui2 = q8[4]; - - const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F; - const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F; - - const int vih1 = ((vh1 >> i) << 4) & 0x10101010; - const int vih2 = ((vh2 >> i) << 4) & 0x10101010; - - const int vi1 = vil1 | vih1; - const int vi2 = vil2 | vih2; - - const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - - sumf_d += d8i * (dot1 * sc[i]); - sumf_m += d8i * (dot2 * m[i]); + d8[i] = bq8i->ds.x; + const int * q8 = (const int *)bq8i->qs + (iqs%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; } - return d*sumf_d - dmin*sumf_m; + return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8); #else +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + const int8_t * s = bq5_K->scales; const float d = bq5_K->d; - const float d8_1 = bq8_1[0].d; - const float d8_2 = bq8_1[1].d; + const float d8_1 = bq8_1[0].ds.x; + const float d8_2 = bq8_1[1].ds.x; const int ui1 = *((const int *)bq8_1[0].qs + iqs); const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4); @@ -1755,56 +2251,304 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( return d * sumf_d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + #endif +} + +static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q5_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI5_K; + const int kqsx = k % QI5_K; + + const block_q5_K * bx = ((block_q5_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm; + x_qh[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->qh, kqsx/4); + x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32)); +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + __builtin_assume(i < 2*WARP_SIZE); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + int vl[2]; + int vh[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + const int bq8_offset = QR5_K * (kqsx / (QI8_1/2)); + + vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0]; + vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4]; + + vh[0] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 0] >> bq8_offset; + vh[1] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)]; + uint16_t aux[2]; + const int l = bq8_offset/2; + if (l < 2) { + aux[0] = scales[l+0] & 0x3f3f; + aux[1] = scales[l+2] & 0x3f3f; + } else { + aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2); + aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int l = 0; l < QR5_K; ++l) { + const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2); + u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)]; + u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)]; + d8[l] = y_ds[kqsy / QI8_1].x; + } + + return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8); +} + +#define VDR_q6_K_q8_1 1 + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; + + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; #else return 0.0f; // only to satisfy the compiler #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q6_K * bq6_K = (const block_q6_K *) vbq; const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); - float sumf = 0.0f; - - const float d = bq6_K->d; + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; - int vl; - memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int)); + const int8_t * scales = bq6_K->scales + scale_offset; - int vh; - memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int)); + int u[QR6_K]; + float d8[QR6_K]; for (int i = 0; i < QR6_K; ++i) { - const int sc = bq6_K->scales[scale_offset + 4*i]; + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + 2*i].ds.x; + } - const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i; - const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]); - const float d8i = bq8i->d; + return vec_dot_q6_K_q8_1_impl(vl, vh, u, scales, bq6_K->d, d8); +} - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; +static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030; + __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K)]; + __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)]; + __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)]; - const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI6_K; + const int kqsx = k % QI6_K; + + const block_q6_K * bx = ((block_q6_K *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->ql, kqsx); + x_dm[i * (WARP_SIZE / QI6_K) + kbx].x = bx->d; + x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->qh, kqsx/2); + x_sc[i * (WARP_SIZE / 8) + k/8] = get_int_from_int8(bx->scales, kqsx/8); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + __builtin_assume(i < GGML_CUDA_MMQ_Y); + __builtin_assume(j < WARP_SIZE); + __builtin_assume(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const int bq8_offset = 2 * QR6_K * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)); + + const int vh = x_qh[i * (WARP_SIZE/2) + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift; + + const int x_sc_offset = i * (WARP_SIZE/8) + kbx * (QI6_K/8); + const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; - sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + for (int l = 0; l < QR6_K; ++l) { + const int kqsy = j * (QR6_K*WARP_SIZE) + kbx * (QR6_K*QI6_K) + (bq8_offset + 2*l)*QI8_1 + kqsx % QI8_1; + u[l] = y_qs[kqsy]; + d8[l] = y_ds[kqsy / QI8_1].x; } - return d*sumf; -#else - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, x_dm[i * (WARP_SIZE/QI6_K) + kbx].x, d8); +} + +template +static __global__ void mul_mat_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int tid_x = threadIdx.x; + const int tid_y = threadIdx.y; + + const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y; + const int & row_x_0 = row_dst_0; + const int row_dst = row_dst_0 + tid_x; + + const int col_dst_0 = blockIdx.y*WARP_SIZE; + const int & col_y_0 = col_dst_0; + + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1; + + __shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)]; + __shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col]; + + float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += 8) { + load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + i + tid_y, tid_x, blocks_per_row_x); + } + + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + tid_x; + const int kby = kqs / QI8_1; + + for (int i = 0; i < WARP_SIZE; i += 8) { + const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby]; + + tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1); + } + } + + for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) { + const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE; + const int kby = tid_x % blocks_per_tile_y_col; + const int col_y_eff = min(col_y_0 + ids, ncols_y-1); + tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby] = y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds; + } + + __syncthreads(); + +#if __CUDA_ARCH__ >= 700 // TODO: actually test this with compute capability 7.X cards +#pragma unroll +#endif // __CUDA_ARCH__ >= 700 + for (int k = 0; k < WARP_SIZE/vdr; ++k) { +#pragma unroll + for (int j = 0; j < WARP_SIZE; j += 8) { +#pragma unroll + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) { + sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, + tid_x + i, tid_y + j, k); + } + } + } + + __syncthreads(); + } + + + if (row_dst >= nrows_dst) { + return; + } + + for (int j = 0; j < WARP_SIZE; j += 8) { + const int col_dst = col_dst_0 + j + tid_y; + + if (col_dst >= ncols_dst) { + return; + } + + for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) { + dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8]; + } + } } -template +template static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.y*blockDim.y + threadIdx.y; @@ -1813,7 +2557,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = WARP_SIZE / qi; + const int blocks_per_warp = vdr * WARP_SIZE / qi; // partial sum for each thread float tmp = 0.0f; @@ -1822,11 +2566,11 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * const block_q8_1 * y = (const block_q8_1 *) vy; for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index + const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index - const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx + const int iby = (i + threadIdx.x / (qi/vdr)) * qk/QK8_1; // y block index that aligns with ibx - const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int + const int iqs = threadIdx.x % (qi/vdr); // x block quant index when casting the quants to int tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); } @@ -1859,11 +2603,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons const int y_offset = qr == 1 ? 1 : qk/2; // partial sum for each thread -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics #else float tmp = 0.0f; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; @@ -1883,7 +2627,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons // matrix multiplication // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 tmp += __hmul2(v, { y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset] @@ -1891,7 +2635,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons #else tmp += v.x * y[iybs + iqs + j/qr + 0]; tmp += v.y * y[iybs + iqs + j/qr + y_offset]; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } } @@ -1902,11 +2646,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } if (tid == 0) { -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 dst[row] = tmp.x + tmp.y; #else dst[row] = tmp; -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } } @@ -2203,9 +2947,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con rms_norm_f32<<>>(x, dst, ncols, eps); } -static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - quantize_q8_1<<>>(x, vy, ndata, k); +static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { + const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ky, 1); + const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx, kx_padded); } static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -2366,7 +3112,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2375,7 +3121,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2384,7 +3130,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2393,7 +3139,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2402,7 +3148,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2411,7 +3157,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2420,7 +3166,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2429,10 +3175,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per - // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales - // is better amortized. - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2441,10 +3184,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per - // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales - // is better amortized. - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2453,7 +3193,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q + mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -2500,6 +3240,126 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static void ggml_mul_mat_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q4_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q8_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q2_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q3_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y; + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q + <<>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + static void ggml_mul_mat_p021_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, cudaStream_t stream) { @@ -2965,6 +3825,83 @@ inline void ggml_cuda_op_rms_norm( (void) i1; } +inline void ggml_cuda_op_mul_mat_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddq_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t i01_diff = i01_high - i01_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff; + + const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ? + ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + size_t as; + void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as); + quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main); + + switch (src0->type) { + case GGML_TYPE_Q4_0: + ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q4_1: + ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_0: + ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_1: + ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q2_K: + ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q3_K: + ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q5_K: + ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + break; + default: + GGML_ASSERT(false); + break; + } + + ggml_cuda_pool_free(src1_q8_1, as); + + (void) src1; + (void) dst; + (void) src0_ddf_i; + (void) i02; + (void) i1; +} + inline void ggml_cuda_op_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ -3006,7 +3943,7 @@ inline void ggml_cuda_op_mul_mat_vec( ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; size_t as; void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as); - quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main); + quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main); switch (src0->type) { case GGML_TYPE_Q4_0: @@ -3047,7 +3984,7 @@ inline void ggml_cuda_op_mul_mat_vec( ggml_cuda_pool_free(src1_q8_1, as); } else { // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 size_t ash; dfloat * src1_dfloat = nullptr; // dfloat == half @@ -3063,7 +4000,7 @@ inline void ggml_cuda_op_mul_mat_vec( } #else dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 switch (src0->type) { case GGML_TYPE_Q4_0: @@ -3104,11 +4041,11 @@ inline void ggml_cuda_op_mul_mat_vec( break; } -#ifdef GGML_CUDA_DMMV_F16 +#ifdef GGML_CUDA_F16 if (src1_convert_f16) { ggml_cuda_pool_free(src1_dfloat, ash); } -#endif // GGML_CUDA_DMMV_F16 +#endif // GGML_CUDA_F16 } (void) src1; @@ -3363,7 +4300,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm int64_t row_low, row_high; if (split) { row_low = id == 0 ? 0 : nrows0*g_tensor_split[id]; + row_low -= row_low % GGML_CUDA_MMQ_Y; + row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; + row_high -= row_high % GGML_CUDA_MMQ_Y; } else { row_low = 0; row_high = nrows0*i02_divisor; @@ -3717,7 +4657,16 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); +#ifdef GGML_CUDA_CUBLAS + const bool use_mul_mat_q = false; +#else + const bool use_mul_mat_q = ggml_is_quantized(src0->type); +#endif // GGML_CUDA_CUBLAS + if (use_mul_mat_q) { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); + } else { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); + } } } else { GGML_ASSERT(false); @@ -3827,7 +4776,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { row_high = nrows; } else if (backend == GGML_BACKEND_GPU_SPLIT) { row_low = id == 0 ? 0 : nrows*g_tensor_split[id]; + row_low -= row_low % GGML_CUDA_MMQ_Y; + row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1]; + row_high -= row_high % GGML_CUDA_MMQ_Y; } else { GGML_ASSERT(false); } -- cgit v1.2.3