author     Johannes Gäßler <johannesg@5d6.de>    2023-07-29 23:04:44 +0200
committer  GitHub <noreply@github.com>           2023-07-29 23:04:44 +0200
commit     11f3ca06b8c66b0427aab0a472479da22553b472 (patch)
tree       8e934ff0d93a78447d996b00561f7ff826c3533f
parent     9baf9ef304f330009d5a93b7390280a0fd27c9a1 (diff)
CUDA: Quantized matrix matrix multiplication (#2160)
* mmq implementation for non k-quants
* q6_K
* q2_K
* q3_k
* q4_K
* vdr
* q5_K
* faster q8_1 loading
* loop unrolling
* add __restrict__
* q2_K sc_high
* GGML_CUDA_MMQ_Y
* Updated Makefile
* Update Makefile
* DMMV_F16 -> F16
* Updated README, CMakeLists
* Fix CMakeLists.txt
* Fix CMakeLists.txt
* Fix multi GPU out-of-bounds
-rw-r--r--   CMakeLists.txt      8
-rw-r--r--   Makefile           15
-rw-r--r--   README.md           6
-rw-r--r--   ggml-cuda.cu     1546
4 files changed, 1273 insertions, 302 deletions
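The core idea of the patch: instead of dequantizing the quantized weight matrix to f32 and multiplying with cuBLAS, the new quantized matrix matrix multiplication ("mmq") kernels work directly on the packed integer data using the dp4a instruction (four int8 multiply-adds accumulated into an int32 per instruction), applying the per-block scales only once at the end. A minimal standalone sketch of that primitive, simplified to a plain q8-by-q8 row dot product (the kernels in the patch additionally handle nibble unpacking, mins and shared-memory tiles):

// dot product of two rows stored as packed int8 values, four per 32-bit word
static __device__ __forceinline__ float dot_q8_rows(
    const int * x, const int * y, const float dx, const float dy, const int n32) {
    int sumi = 0;
    for (int k = 0; k < n32; ++k) {
        sumi = __dp4a(x[k], y[k], sumi); // 4 int8 products + accumulate, one instruction
    }
    return dx * dy * sumi; // apply the per-row/per-block scales once
}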
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c43e65e..6e1abea 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -67,7 +67,9 @@ endif()
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_BLAS "llama: use BLAS" OFF)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
-option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
+option(LLAMA_CUBLAS "llama: use CUDA" OFF)
+option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF)
+set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
@@ -251,6 +253,10 @@ if (LLAMA_CUBLAS)
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
+ if (LLAMA_CUDA_CUBLAS)
+ add_compile_definitions(GGML_CUDA_CUBLAS)
+ endif()
+ add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
if (LLAMA_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
endif()
diff --git a/Makefile b/Makefile
index 2035c52..3d1fff8 100644
--- a/Makefile
+++ b/Makefile
@@ -194,7 +194,7 @@ ifdef LLAMA_CUBLAS
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
- NVCCFLAGS = --forward-unknown-to-host-compiler
+ NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
else
@@ -220,14 +220,25 @@ else ifdef LLAMA_CUDA_DMMV_Y
else
NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
endif # LLAMA_CUDA_MMV_Y
+ifdef LLAMA_CUDA_F16
+ NVCCFLAGS += -DGGML_CUDA_F16
+endif # LLAMA_CUDA_F16
ifdef LLAMA_CUDA_DMMV_F16
- NVCCFLAGS += -DGGML_CUDA_DMMV_F16
+ NVCCFLAGS += -DGGML_CUDA_F16
endif # LLAMA_CUDA_DMMV_F16
ifdef LLAMA_CUDA_KQUANTS_ITER
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
else
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
+ifdef LLAMA_CUDA_MMQ_Y
+ NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
+else
+ NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
+endif # LLAMA_CUDA_MMQ_Y
+ifdef LLAMA_CUDA_CUBLAS
+ NVCCFLAGS += -DGGML_CUDA_CUBLAS
+endif # LLAMA_CUDA_CUBLAS
ifdef LLAMA_CUDA_CCBIN
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif
diff --git a/README.md b/README.md
index 6a3268d..42fc42b 100644
--- a/README.md
+++ b/README.md
@@ -402,10 +402,12 @@ Building the program with BLAS support may lead to some performance improvements
| Option | Legal values | Default | Description |
|-------------------------|------------------------|---------|-------------|
+ | LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
+ | LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
- | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
- | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
+ | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
+ | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
- #### CLBlast
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 511f48c..0a43fb5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -52,13 +52,41 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} while (0)
#endif // CUDART_VERSION >= 11
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
-#endif //GGML_CUDA_DMMV_F16
+#endif //GGML_CUDA_F16
+
+static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
+ const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+ int x32 = 0;
+ x32 |= x16[0] << 0;
+ x32 |= x16[1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
+ const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+ int x32 = 0;
+ x32 |= x16[0] << 0;
+ x32 |= x16[1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
+ return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
+ return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
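// Why the split into two 16-bit loads above: blocks that start with a single 2-byte half
// (q4_0, q5_0, q8_0) only guarantee 2-byte alignment for their qs data, so a direct
// 32-bit load could be misaligned. Blocks whose layout gives the data 4-byte alignment
// (e.g. block_q8_1 with its leading half2) can use the *_aligned variants, which read a
// whole int in one load.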
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
@@ -87,8 +115,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
#define QR4_1 2
#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
- half d; // delta
- half m; // min
+ half2 dm; // dm.x = delta, dm.y = min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
@@ -107,8 +134,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
#define QR5_1 2
#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
- half d; // delta
- half m; // min
+ half2 dm; // dm.x = delta, dm.y = min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
@@ -127,13 +153,19 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
#define QR8_1 1
#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
- half d; // delta
- half s; // unquantized sum
+ half2 ds; // ds.x = delta, ds.y = sum
int8_t qs[QK8_0]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
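// Note on the struct changes above and below: pairs of half fields (delta+min for
// q4_1/q5_1, delta+sum for q8_1, the two super-block scales for q2_K/q4_K/q5_K) are
// merged into a single half2. The pair is then fetched with one 32-bit load, and the
// dot product kernels can combine both components at once, e.g. with a single __hmul2
// (see vec_dot_q4_1_q8_1_impl further down).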
-typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
+typedef void (*load_tiles_cuda_t)(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_cuda_t)(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
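// How the three hooks above are wired together, as a rough sketch (simplified; the full
// tiled kernel built on them further down in this patch also manages the quantized y
// tile, several output values per thread and bounds handling):
//
//   allocate_tiles(&x_ql, &x_dm, &x_qh, &x_sc);            // shared memory for one
//                                                          // GGML_CUDA_MMQ_Y x WARP_SIZE tile of x
//   for (each WARP_SIZE-wide slice of the shared row dimension) {
//       load_tiles(vx, x_ql, x_dm, x_qh, x_sc, i, k, ...); // cooperative load + unpack
//       __syncthreads();
//       sum[i][j] += vec_dot_*_mul_mat(x_ql, x_dm, x_qh, x_sc, y_qs, y_ds, i, j, k);
//       __syncthreads();
//   }
//   // finally write the accumulated sums to dst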
//================================= k-quants
@@ -150,8 +182,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
+ half2 dm; // super-block scale for quantized scales/mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
@@ -180,8 +211,7 @@ typedef struct {
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
#else
typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
+ half2 dm; // super-block scale for quantized scales/mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
@@ -200,11 +230,10 @@ typedef struct {
static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
#else
typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
- uint8_t qh[QK_K/8]; // quants, high bit
- uint8_t qs[QK_K/2]; // quants, low 4 bits
+ half2 dm; // super-block scale for quantized scales/mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
@@ -233,6 +262,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#ifndef GGML_CUDA_MMQ_Y
+#define GGML_CUDA_MMQ_Y 64
+#endif // GGML_CUDA_MMQ_Y
+
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
@@ -367,33 +400,33 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
v.x = vui & 0xF;
v.y = vui >> 4;
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
v = __hsub2(v, {8.0f, 8.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 8.0f) * d;
v.y = (v.y - 8.0f) * d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
- const dfloat d = x[ib].d;
- const dfloat m = x[ib].m;
+ const dfloat d = x[ib].dm.x;
+ const dfloat m = x[ib].dm.y;
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -410,20 +443,20 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
v = __hsub2(v, {16.0f, 16.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 16.0f) * d;
v.y = (v.y - 16.0f) * d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
- const dfloat d = x[ib].d;
- const dfloat m = x[ib].m;
+ const dfloat d = x[ib].dm.x;
+ const dfloat m = x[ib].dm.y;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -434,13 +467,13 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
@@ -451,12 +484,12 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
#else
v.x *= d;
v.y *= d;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
//================================== k-quants
@@ -475,8 +508,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
- float dall = x[i].d;
- float dmin = x[i].dmin;
+ float dall = x[i].dm.x;
+ float dmin = x[i].dm.y;
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -486,8 +519,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il;
- float dall = x[i].d;
- float dmin = x[i].dmin;
+ float dall = x[i].dm.x;
+ float dmin = x[i].dm.y;
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
@@ -573,8 +606,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + n*ir;
- const float dall = x[i].d;
- const float dmin = x[i].dmin;
+ const float dall = x[i].dm.x;
+ const float dmin = x[i].dm.y;
const uint8_t * q = x[i].qs + 32*il + n*ir;
@@ -612,8 +645,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + 2*ir;
- const float dall = x[i].d;
- const float dmin = x[i].dmin;
+ const float dall = x[i].dm.x;
+ const float dmin = x[i].dm.y;
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
@@ -725,8 +758,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
- const float dall = x[i].d;
- const float dmin = x[i].dmin;
+ const float dall = x[i].dm.x;
+ const float dmin = x[i].dm.y;
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
@@ -768,9 +801,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
uaux[0] = s[0] & 0x0f0f0f0f;
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
- const half2 * dh = (const half2 *)&x[i].d;
-
- const float2 dall = __half22float2(dh[0]);
+ const float2 dall = __half22float2(x[i].dm);
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
@@ -948,8 +979,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
- const float dall = x[i].d;
- const float dmin = x[i].dmin;
+ const float dall = x[i].dm.x;
+ const float dmin = x[i].dm.y;
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
@@ -1081,8 +1112,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
- const float dall = x[i].d;
- const float dmin = x[i].dmin;
+ const float dall = x[i].dm.x;
+ const float dmin = x[i].dm.y;
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
@@ -1270,19 +1301,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
v.y = x[ib + iqs + 1];
}
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
+ const int ix = blockDim.x*blockIdx.x + threadIdx.x;
- if (i >= k) {
+ if (ix >= kx_padded) {
return;
}
+ const int iy = blockDim.y*blockIdx.y + threadIdx.y;
+
+ const int i_padded = iy*kx_padded + ix;
+
block_q8_1 * y = (block_q8_1 *) vy;
- const int ib = i / QK8_1; // block index
- const int iqs = i % QK8_1; // quant index
+ const int ib = i_padded / QK8_1; // block index
+ const int iqs = i_padded % QK8_1; // quant index
- const float xi = i < ndata ? x[i] : 0.0f;
+ const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
@@ -1301,8 +1336,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
return;
}
- y[ib].d = d;
- y[ib].s = sum;
+ y[ib].ds.x = d;
+ y[ib].ds.y = sum;
}
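// Hedged sketch of a host-side launch for the 2D quantization above; the wrapper name
// and exact grid shape are an assumption (illustrative only), but they follow the
// kernel's padding logic: grid.y walks the rows, grid.x covers the padded row length,
// and elements past kx are quantized as 0.0f.
static void quantize_row_q8_1_cuda_sketch(
    const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
    const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
    const dim3 num_blocks(block_num_x, ky, 1);
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}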
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1326,45 +1361,79 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
y[iybs + iqs + y_offset] = v.y;
}
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
- int vi;
- memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
- const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
- const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
+#define VDR_q4_0_q8_1 1
- const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+ const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
// subtract 8 from each quantized value
- const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
- const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
+ const int vi0 = (vi >> 0) & 0x0F0F0F0F;
+ const int vi1 = (vi >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
int sumi = __dp4a(vi0, ui0, 0);
sumi = __dp4a(vi1, ui1, sumi);
- return sumi*d;
+ return __half2float(d4) * (sumi * __half2float(ds8.x) - (8/QI4_0) * __half2float(ds8.y));
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
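// Where this formula comes from: the exact block dot product is
//   sum_i d4*(q4_i - 8) * y_i  =  d4 * ( d8 * sum_i q4_i*q8_i  -  8 * sum_i y_i )
// with ds8.x = d8 and ds8.y = sum_i y_i (the unquantized sum stored by quantize_q8_1).
// QI4_0 threads each compute a partial sumi for the same block, so every thread only
// subtracts its share (8/QI4_0)*ds8.y of the constant term.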
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+ const int vi = get_int_from_uint8(bq4_0->qs, iqs);
+ const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
+ const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_0);
+
+ return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0)];
+
+ *x_ql = tile_x_qs;
+ *x_dm = tile_x_d;
+}
+
+static __device__ __forceinline__ void load_tiles_q4_0(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI4_0;
+ const int kqsx = k % QI4_0;
+
+ const block_q4_0 * bx = ((block_q4_0 *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI4_0) + kbx].x = bx->d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+ return vec_dot_q4_0_q8_1_impl(
+ x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+ x_dm[i * (WARP_SIZE/QI4_0) + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
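// Index mapping into the y tile: every QI8_1/2 consecutive x ints (one q4_0 block)
// correspond to one q8_1 block of QI8_1 ints in y. kyqs is the y int that matches the
// low nibbles of x int k; the matching high-nibble values sit QI8_1/2 ints further on.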
- const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
- const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
- const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+#define VDR_q4_1_q8_1 1
- const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
- const float m = bq4_1->m;
- const float s = bq8_1->s;
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+ const int & vi, const int & ui0, const int & ui1, const half2 & dm4, const half2 & ds8) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
@@ -1372,184 +1441,421 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
int sumi = __dp4a(vi0, ui0, 0);
sumi = __dp4a(vi1, ui1, sumi);
- return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
+#ifdef GGML_CUDA_F16
+ const half2 tmp = __hmul2(dm4, ds8);
+ const float d4d8 = __half2float(tmp.x);
+ const float m4s8 = __half2float(tmp.y);
+#else
+ const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
+ const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI8_1/QR4_1 to compensate for multiple threads adding it
+ return sumi * d4d8 + m4s8 / (QI8_1 / QR4_1);
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
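// dm4 packs (d4, m4) and ds8 packs (d8, s8), so with fp16 math a single __hmul2 yields
// both d4*d8 and m4*s8 at once. The m4*s8 term is a per-block constant, but QI8_1/QR4_1
// threads each add it, hence the division so the contributions sum to exactly one
// m4*s8 per block.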
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+ const int vi = get_int_from_uint8_aligned(bq4_1->qs, iqs);
+ const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
+ const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI4_1);
+
+ return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1)];
+
+ *x_ql = tile_x_qs;
+ *x_dm = tile_x_dm;
+}
+
+static __device__ __forceinline__ void load_tiles_q4_1(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI4_1;
+ const int kqsx = k % QI4_1;
+
+ const block_q4_1 * bx = ((block_q4_1 *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI4_1) + kbx] = bx->dm;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+ return vec_dot_q4_1_q8_1_impl(
+ x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+ x_dm[i * (WARP_SIZE/QI4_1) + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+#define VDR_q5_0_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+ const int & qs, const int & qh, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) {
+
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (qh << 25) & 0x10000000; // 3 -> 28
+ int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+
+ int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (qh << 9) & 0x10000000; // 19 -> 28
+ sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+
+ return __half2float(d5) * (sumi*__half2float(ds8.x) - (16/QI5_0) * __half2float(ds8.y));
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
- int qs;
- memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
- const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
- const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
- const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
- const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
-
- const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
-
- int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
- vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
- vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
- vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
- vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
- vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
+ const int qs = get_int_from_uint8(bq5_0->qs, iqs);
+ const int qh = get_int_from_uint8(bq5_0->qh, 0) >> (4 * iqs);
+ const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
+ const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0);
+
+ return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
+ __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
+
+ *x_ql = tile_x_ql;
+ *x_qh = tile_x_qh;
+ *x_dm = tile_x_d;
+}
+
+static __device__ __forceinline__ void load_tiles_q5_0(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI5_0;
+ const int kqsx = k % QI5_0;
+
+ const block_q5_0 * bx = ((block_q5_0 *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
+ x_qh[i * (WARP_SIZE / QI5_0) + kbx] = get_int_from_uint8(bx->qh, 0);
+ x_dm[i * (WARP_SIZE / QI5_0) + kbx].x = bx->d;
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+ const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
+
+ return vec_dot_q5_0_q8_1_impl(
+ x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs],
+ y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+#define VDR_q5_1_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+ const int & qs, const int & qh, const int & ui0, const int & ui1, const half2 & dm5, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
+ vi0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (qh << 25) & 0x10000000; // 3 -> 28
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
- int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
- vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
- vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
- vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
- vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
- vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
+ int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
+ vi1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (qh << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
- return sumi*d;
+#ifdef GGML_CUDA_F16
+ const half2 tmp = __hmul2(dm5, ds8);
+ const float d5d8 = __half2float(tmp.x);
+ const float m5s8 = __half2float(tmp.y);
+#else
+ const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
+ const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_F16
+
+ return sumi*d5d8 + m5s8/QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
+
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
- const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
- const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
- const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
- const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
- const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
-
- const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
- const float m = bq5_1->m;
- const float s = bq8_1->s;
-
- int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
- vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
- vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
- vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
- vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
- int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+ const int qs = get_int_from_uint8_aligned(bq5_1->qs, iqs);
+ const int qh = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * iqs);
+ const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
+ const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_1);
- int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
- vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
- vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
- vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
- vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
- sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+ return vec_dot_q5_1_q8_1_impl(qs, qh, ui0, ui1, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
+
+ *x_ql = tile_x_ql;
+ *x_qh = tile_x_qh;
+ *x_dm = tile_x_dm;
+}
- return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
+static __device__ __forceinline__ void load_tiles_q5_1(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI5_1;
+ const int kqsx = k % QI5_1;
+
+ const block_q5_1 * bx = ((block_q5_1 *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
+ x_qh[i * (WARP_SIZE / QI5_1) + kbx] = get_int_from_uint8(bx->qh, 0);
+ x_dm[i * (WARP_SIZE / QI5_1) + kbx] = bx->dm;
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+ const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
+
+ return vec_dot_q5_1_q8_1_impl(
+ x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2*WARP_SIZE) + kyqs],
+ y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+#define VDR_q8_0_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
+ const int & vi, const int & ui, const half & d8_0, const half2 & ds8_1) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ // SIMD dot product of quantized values
+ const int sumi = __dp4a(vi, ui, 0);
+
+ return sumi * __half2float(d8_0) * __half2float(ds8_1.x);
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
- int vi;
- memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
- const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+ const int vi = get_int_from_int8(bq8_0->qs, iqs);
+ const int ui = get_int_from_int8_aligned(bq8_1->qs, iqs);
- const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
+ return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds);
+}
- // SIMD dot product of quantized values
- int sumi = __dp4a(vi, ui, 0);
+static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0)];
+
+ *x_ql = tile_x_qs;
+ *x_dm = tile_x_d;
+}
+
+static __device__ __forceinline__ void load_tiles_q8_0(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI8_0;
+ const int kqsx = k % QI8_0;
+
+ const block_q8_0 * bx = ((block_q8_0 *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI8_0) + kbx].x = bx->d;
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ return vec_dot_q8_0_q8_1_impl(
+ x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k],
+ x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
+#define VDR_q2_K_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
+ const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const half2 & dm, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+ for (int i = 0; i < QR2_K; ++i) {
+ const int sc = scales[2*i];
+
+ const int vi = (v >> (2*i)) & 0x03030303;
+
+ sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+ int sc_high = sc >> 4;
+ sc_high |= sc_high << 8;
+ sc_high |= sc_high << 16;
+ sumf_m += d8[i] * __dp4a(sc_high, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+ }
- return sumi*d;
+ const float2 dmf = __half22float2(dm);
+
+ return dmf.x*sumf_d - dmf.y*sumf_m;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
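// The subtraction term of q2_K: sc >> 4 is the 4-bit per-group min. Broadcasting it into
// all four bytes of sc_high lets __dp4a(sc_high, u[i], 0) compute min * (sum of the four
// q8 values) in a single instruction; the result is weighted by d8[i] and finally scaled
// by dm.y, the super-block scale of the mins.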
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q2_K * bq2_K = (const block_q2_K *) vbq;
const int bq8_offset = QR2_K * (iqs / QI8_1);
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
- float sumf_d = 0.0f;
- float sumf_m = 0.0f;
+ const uint8_t * scales = bq2_K->scales + scale_offset;
- const float d = bq2_K->d;
- const float dmin = bq2_K->dmin;
+ const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+ int u[QR2_K];
+ float d8[QR2_K];
- const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]);
+ for (int i = 0; i < QR2_K; ++ i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + i].ds.x;
+ }
- for (int i = 0; i < QR2_K; ++i) {
- const int sc = bq2_K->scales[scale_offset + 2*i];
+ return vec_dot_q2_K_q8_1_impl(v, u, scales, bq2_K->dm, d8);
+}
- const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const float d8i = bq8i->d;
+static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- const int vi = (v >> (2*i)) & 0x03030303;
- const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
+ __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
- sumf_d += d8i * (__dp4a(vi, ui, 0) * (sc & 0xF)); // SIMD dot product
- sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values
- }
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_sc = tile_x_sc;
+}
- return d*sumf_d - dmin*sumf_m;
-#else
- return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+static __device__ __forceinline__ void load_tiles_q2_K(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI2_K;
+ const int kqsx = k % QI2_K;
+
+ const block_q2_K * bx = ((block_q2_K *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI2_K) + kbx] = bx->dm;
+ x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->scales, kqsx / 4);
}
-static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+ __builtin_assume(i < GGML_CUDA_MMQ_Y);
+ __builtin_assume(j < WARP_SIZE);
+ __builtin_assume(k < WARP_SIZE);
- const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
- const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+ const int kbx = k / QI2_K;
+ const int kqsx = k % QI2_K;
- float sumf = 0.0f;
+ const int bq8_offset = QR2_K * (kqsx / QI8_1);
+ const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
- const float d = bq3_K->d;
+ const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16 + scale_offset;
+
+ int u[QR2_K];
+ float d8[QR2_K];
- int vl;
- memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
+ for (int l = 0; l < QR2_K; ++ l) {
+ const int y_qs_index = j * (QR2_K*WARP_SIZE) + kbx * (QR2_K*QI2_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
+ u[l] = y_qs[y_qs_index];
+ d8[l] = y_ds[y_qs_index / QI8_1].x;
+ }
+
+ return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + kbx], d8);
+}
- int vh;
- memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
- vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted
- vh >>= bq8_offset;
+#define VDR_q3_K_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl(
+ const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const int & scale_offset, const float & d, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ float sumf = 0.0f;
for (int i = 0; i < QR3_K; ++i) {
const int isc = scale_offset + 2*i;
const int isc_low = isc % (QK_K/32);
const int sc_shift_low = 4 * (isc / (QK_K/32));
- const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
+ const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
const int isc_high = isc % (QK_K/64);
const int sc_shift_high = 2 * (isc / (QK_K/64));
- const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+ const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
const int sc = (sc_low | sc_high) - 32;
- const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
- const float d8i = bq8i->d;
-
const int vil = (vl >> (2*i)) & 0x03030303;
const int vih = ((vh >> i) << 2) & 0x04040404;
const int vi = __vsubss4(vil, vih);
- sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+ sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d*sumf;
@@ -1558,31 +1864,136 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
-static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
- const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+ const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+ const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+ const float d = bq3_K->d;
+
+ const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+ for (int i = 0; i < QR3_K; ++i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + i].ds.x;
+ }
+
+ return vec_dot_q3_K_q8_1_impl(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
+ __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE / 2)];
+ __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
+
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_qh = tile_x_qh;
+ *x_sc = tile_x_sc;
+}
+
+static __device__ __forceinline__ void load_tiles_q3_K(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI3_K;
+ const int kqsx = k % QI3_K;
+
+ const block_q3_K * bx = ((block_q3_K *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI3_K) + kbx].x = bx->d;
+ x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->hmask, kqsx / 2);
+ x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8(bx->scales, kqsx / 4);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ const int kbx = k / QI3_K;
+ const int kqsx = k % QI3_K;
+
+ const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
+ const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
+
+ const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16;
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~x_qh[i * (WARP_SIZE/2) + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+ for (int l = 0; l < QR3_K; ++ l) {
+ const int y_qs_index = j * (QR3_K*WARP_SIZE) + kbx * (QR3_K*QI3_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
+ u[l] = y_qs[y_qs_index];
+ d8[l] = y_ds[y_qs_index / QI8_1].x;
+ }
+
+ return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI3_K) + kbx].x, d8);
+}
+
+#define VDR_q4_K_q8_1 2
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
float sumf_d = 0.0f;
float sumf_m = 0.0f;
+ for (int i = 0; i < QR4_K; ++i) {
+ const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+ const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+ const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ }
+
+ return __half2float(dm4.x)*sumf_d - __half2float(dm4.y)*sumf_m;
+
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
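// dot2 feeds the constant 0x01010101 to __dp4a, so every byte of u is multiplied by 1:
// it is simply the byte-wise sum of the eight q8 values of this group. Multiplied by the
// quantized min m[i] and d8[i], it forms the constant part that is subtracted, scaled by
// dm4.y (the super-block scale of the mins).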
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
#ifndef GGML_QKK_64
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
- const float d = bq4_K->d;
- const float dmin = bq4_K->dmin;
-
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
- const int v1 = q4[0];
- const int v2 = q4[4];
+ v[0] = q4[0];
+ v[1] = q4[4];
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
uint16_t aux[2];
@@ -1598,27 +2009,24 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const uint8_t * m = sc + 2;
for (int i = 0; i < QR4_K; ++i) {
-
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const float d8i = bq8i->d;
- const int * q8 = (const int *)bq8i->qs + (iqs%4);
- const int ui1 = q8[0];
- const int ui2 = q8[4];
-
- const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
- const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
+ d8[i] = bq8i->ds.x;
- const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
- const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
-
- sumf_d += d8i * (dot1 * sc[i]);
- sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ const int * q8 = (const int *)bq8i->qs + (iqs%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
}
- return d*sumf_d - dmin*sumf_m;
+ return vec_dot_q4_K_q8_1_impl(v, u, sc, m, bq4_K->dm, d8);
#else
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16;
@@ -1629,8 +2037,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const float dall = bq4_K->d[0];
const float dmin = bq4_K->d[1];
- const float d8_1 = bq8_1[0].d;
- const float d8_2 = bq8_1[1].d;
+ const float d8_1 = bq8_1[0].ds.x;
+ const float d8_2 = bq8_1[1].ds.x;
const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
@@ -1651,7 +2059,111 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
return dall * sumf_d - dmin * sumf_m;
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
#endif
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)];
+ __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
+
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_sc = tile_x_sc;
+}
+
+static __device__ __forceinline__ void load_tiles_q4_K(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI4_K;
+ const int kqsx = k % QI4_K;
+
+ const block_q4_K * bx = ((block_q4_K *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm;
+ x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32));
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ __builtin_assume(i < GGML_CUDA_MMQ_Y);
+ __builtin_assume(j < WARP_SIZE);
+ __builtin_assume(k < WARP_SIZE);
+
+ const int kbx = k / QI6_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
+
+ // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * (kqsx / (QI8_1/2));
+
+ v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
+ v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
+
+ const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)];
+ uint16_t aux[2];
+ const int l = bq8_offset/2;
+ if (l < 2) {
+ aux[0] = scales[l+0] & 0x3f3f;
+ aux[1] = scales[l+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int l = 0; l < QR4_K; ++l) {
+ const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2);
+ u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
+ u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
+ d8[l] = y_ds[kqsy / QI8_1].x;
+ }
+
+ return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8);
+}
+
+#define VDR_q5_K_q8_1 2
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+ const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+ for (int i = 0; i < QR5_K; ++i) {
+ const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+ const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+ const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+ const int v0i = vl0i | vh0i;
+ const int v1i = vl1i | vh1i;
+
+ const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+ const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]);
+
+ }
+
+ return __half2float(dm5.x)*sumf_d - __half2float(dm5.y)*sumf_m;
#else
return 0.0f; // only to satisfy the compiler
@@ -1659,28 +2171,25 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#ifndef GGML_QKK_64
const block_q5_K * bq5_K = (const block_q5_K *) vbq;
-#ifndef GGML_QKK_64
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
- float sumf_d = 0.0f;
- float sumf_m = 0.0f;
-
- const float d = bq5_K->d;
- const float dmin = bq5_K->dmin;
-
- const int vl1 = ql[0];
- const int vl2 = ql[4];
+ vl[0] = ql[0];
+ vl[1] = ql[4];
- const int vh1 = qh[0] >> bq8_offset;
- const int vh2 = qh[4] >> bq8_offset;
+ vh[0] = qh[0] >> bq8_offset;
+ vh[1] = qh[4] >> bq8_offset;
const uint16_t * scales = (const uint16_t *)bq5_K->scales;
uint16_t aux[2];
@@ -1696,40 +2205,27 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
const uint8_t * m = sc + 2;
for (int i = 0; i < QR5_K; ++i) {
-
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const float d8i = bq8i->d;
- const int * q8 = (const int *)bq8i->qs + (iqs%4);
- const int ui1 = q8[0];
- const int ui2 = q8[4];
-
- const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
- const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
-
- const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
- const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
-
- const int vi1 = vil1 | vih1;
- const int vi2 = vil2 | vih2;
-
- const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
- const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
-
- sumf_d += d8i * (dot1 * sc[i]);
- sumf_m += d8i * (dot2 * m[i]);
+ d8[i] = bq8i->ds.x;
+ const int * q8 = (const int *)bq8i->qs + (iqs%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
}
- return d*sumf_d - dmin*sumf_m;
+ return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8);
#else
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
const int8_t * s = bq5_K->scales;
const float d = bq5_K->d;
- const float d8_1 = bq8_1[0].d;
- const float d8_2 = bq8_1[1].d;
+ const float d8_1 = bq8_1[0].ds.x;
+ const float d8_2 = bq8_1[1].ds.x;
const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
@@ -1755,56 +2251,304 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
return d * sumf_d;
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
#endif
+}
+
+static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K)];
+ __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)];
+ __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
+
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_qh = tile_x_qh;
+ *x_sc = tile_x_sc;
+}
+
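These tile sizes grow linearly with GGML_CUDA_MMQ_Y: a taller x tile amortizes the shared-memory loads over more output rows but raises the per-block shared-memory footprint and thus limits occupancy. A rough host-side tally of just the four q5_K x tiles declared above, assuming QK_K == 256 (so QI5_K == 32) and a 4-byte half2; the q8_1 y tiles inside mul_mat_q come on top of this:

    #include <cstdio>

    int main() {
        const int WARP_SIZE = 32;
        const int QI5_K = 32; // assumed value for QK_K == 256 builds
        const int sizes[] = {32, 64, 128};
        for (int mmq_y : sizes) {
            const size_t ql = sizeof(int) * mmq_y * (WARP_SIZE + 1);       // tile_x_ql
            const size_t dm = 4u          * mmq_y * (WARP_SIZE / QI5_K);   // tile_x_dm (half2)
            const size_t qh = sizeof(int) * mmq_y * (WARP_SIZE / 4);       // tile_x_qh
            const size_t sc = sizeof(int) * mmq_y * (3 * WARP_SIZE / 32);  // tile_x_sc
            printf("GGML_CUDA_MMQ_Y=%3d -> %zu bytes of shared memory for the x tiles\n",
                   mmq_y, ql + dm + qh + sc);
        }
    }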
+static __device__ __forceinline__ void load_tiles_q5_K(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI5_K;
+ const int kqsx = k % QI5_K;
+
+ const block_q5_K * bx = ((block_q5_K *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
+ x_dm[i * (WARP_SIZE / QI5_K) + kbx] = bx->dm;
+ x_qh[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->qh, kqsx/4);
+ x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32));
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ __builtin_assume(i < GGML_CUDA_MMQ_Y);
+ __builtin_assume(j < WARP_SIZE);
+ __builtin_assume(k < WARP_SIZE);
+
+ const int kbx = k / QI5_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
+
+ const int bq8_offset = QR5_K * (kqsx / (QI8_1/2));
+
+ vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
+ vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
+
+ vh[0] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 0] >> bq8_offset;
+ vh[1] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 4] >> bq8_offset;
+
+ const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)];
+ uint16_t aux[2];
+ const int l = bq8_offset/2;
+ if (l < 2) {
+ aux[0] = scales[l+0] & 0x3f3f;
+ aux[1] = scales[l+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int l = 0; l < QR5_K; ++l) {
+ const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + kqsx % (QI8_1/2);
+ u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
+ u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
+ d8[l] = y_ds[kqsy / QI8_1].x;
+ }
+
+ return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI5_K) + kbx], d8);
+}
+
+#define VDR_q6_K_q8_1 1
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl(
+ const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ float sumf = 0.0f;
+
+ for (int i = 0; i < QR6_K; ++i) {
+ const int sc = scales[4*i];
+
+ const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+ const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+ sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d*sumf;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
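The __vsubss4((vil | vih), 0x20202020) step performs, on four packed bytes at once and with signed saturation, the usual q6_K reconstruction: 4 low bits from ql plus 2 high bits from qh, re-centred by 32 so each weight lands in [-32, 31]. A per-weight scalar sketch (illustration only):

    #include <cstdint>
    #include <cstdio>

    // One q6_K weight: 4 low bits (from ql) plus 2 high bits (from qh), minus 32.
    static int q6_value(uint8_t low4, uint8_t high2) {
        return (int) ((low4 & 0x0F) | ((high2 & 0x03) << 4)) - 32;
    }

    int main() {
        printf("%d %d %d\n", q6_value(0, 0), q6_value(15, 3), q6_value(8, 2)); // -32 31 8
    }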
static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
- const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q6_K * bq6_K = (const block_q6_K *) vbq;
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
- float sumf = 0.0f;
-
- const float d = bq6_K->d;
+ const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+ const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
- int vl;
- memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int));
+ const int8_t * scales = bq6_K->scales + scale_offset;
- int vh;
- memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int));
+ int u[QR6_K];
+ float d8[QR6_K];
for (int i = 0; i < QR6_K; ++i) {
- const int sc = bq6_K->scales[scale_offset + 4*i];
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+ }
- const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
- const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
- const float d8i = bq8i->d;
+ return vec_dot_q6_K_q8_1_impl(vl, vh, u, scales, bq6_K->d, d8);
+}
- const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030;
+ __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
+ __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K)];
+ __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)];
+ __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)];
- const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_qh = tile_x_qh;
+ *x_sc = tile_x_sc;
+}
+
+static __device__ __forceinline__ void load_tiles_q6_K(
+ const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+ int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+ const int kbx = k / QI6_K;
+ const int kqsx = k % QI6_K;
+
+ const block_q6_K * bx = ((block_q6_K *) vx) + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->ql, kqsx);
+ x_dm[i * (WARP_SIZE / QI6_K) + kbx].x = bx->d;
+ x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->qh, kqsx/2);
+ x_sc[i * (WARP_SIZE / 8) + k/8] = get_int_from_int8(bx->scales, kqsx/8);
+}
- sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
+ const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+ const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+ __builtin_assume(i < GGML_CUDA_MMQ_Y);
+ __builtin_assume(j < WARP_SIZE);
+ __builtin_assume(k < WARP_SIZE);
+
+ const int kbx = k / QI6_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+ const int bq8_offset = 2 * QR6_K * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/4);
+ const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8);
+ const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4));
+
+ const int vh = x_qh[i * (WARP_SIZE/2) + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift;
+
+ const int x_sc_offset = i * (WARP_SIZE/8) + kbx * (QI6_K/8);
+ const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset;
+
+ int u[QR6_K];
+ float d8[QR6_K];
+
+ for (int l = 0; l < QR6_K; ++l) {
+ const int kqsy = j * (QR6_K*WARP_SIZE) + kbx * (QR6_K*QI6_K) + (bq8_offset + 2*l)*QI8_1 + kqsx % QI8_1;
+ u[l] = y_qs[kqsy];
+ d8[l] = y_ds[kqsy / QI8_1].x;
}
- return d*sumf;
-#else
- return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+ return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, x_dm[i * (WARP_SIZE/QI6_K) + kbx].x, d8);
+}
+
+template <int qk, int qr, int qi, typename block_q_t,
+ allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __global__ void mul_mat_q(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ const int blocks_per_row_x = ncols_x / qk;
+ const int blocks_per_col_y = nrows_y / QK8_1;
+ const int blocks_per_warp = WARP_SIZE / qi;
+
+ const int & ncols_dst = ncols_y;
+
+ const int tid_x = threadIdx.x;
+ const int tid_y = threadIdx.y;
+
+ const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y;
+ const int & row_x_0 = row_dst_0;
+ const int row_dst = row_dst_0 + tid_x;
+
+ const int col_dst_0 = blockIdx.y*WARP_SIZE;
+ const int & col_y_0 = col_dst_0;
+
+ int * tile_x_ql = nullptr;
+ half2 * tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+ allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+ const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1;
+
+ __shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
+ __shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
+
+ float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f};
+
+ for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+ for (int i = 0; i < GGML_CUDA_MMQ_Y; i += 8) {
+ load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+ i + tid_y, tid_x, blocks_per_row_x);
+ }
+
+ for (int ir = 0; ir < qr; ++ir) {
+ const int kqs = ir*WARP_SIZE + tid_x;
+ const int kby = kqs / QI8_1;
+
+ for (int i = 0; i < WARP_SIZE; i += 8) {
+ const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+
+ const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby];
+
+ tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
+ }
+ }
+
+ for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) {
+ const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
+ const int kby = tid_x % blocks_per_tile_y_col;
+ const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+ tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby] = y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+ }
+
+ __syncthreads();
+
+#if __CUDA_ARCH__ >= 700 // TODO: actually test this with compute capability 7.X cards
+#pragma unroll
+#endif // __CUDA_ARCH__ >= 700
+ for (int k = 0; k < WARP_SIZE/vdr; ++k) {
+#pragma unroll
+ for (int j = 0; j < WARP_SIZE; j += 8) {
+#pragma unroll
+ for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
+ sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+ tid_x + i, tid_y + j, k);
+ }
+ }
+ }
+
+ __syncthreads();
+ }
+
+
+ if (row_dst >= nrows_dst) {
+ return;
+ }
+
+ for (int j = 0; j < WARP_SIZE; j += 8) {
+ const int col_dst = col_dst_0 + j + tid_y;
+
+ if (col_dst >= ncols_dst) {
+ return;
+ }
+
+ for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
+ dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8];
+ }
+ }
}
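In outline, mul_mat_q gives each thread block a GGML_CUDA_MMQ_Y x WARP_SIZE tile of dst (blockIdx.x selects the row tile, blockIdx.y the column tile), stages the quantized x tile and the q8_1 y tile in shared memory, and accumulates the tile with the per-type vec_dot callback. Ignoring quantization and row padding, the index mapping it produces is the one in this plain host sketch: x is row-major, y holds one row of length ncols_x per output column, and dst is written with leading dimension nrows_dst, matching dst[col_dst*nrows_dst + row_dst]:

    #include <vector>

    // Unblocked host reference for the dst layout produced by mul_mat_q
    // (a sketch for orientation, not a drop-in replacement for the kernel).
    static void mul_mat_ref(const std::vector<float> & x, const std::vector<float> & y,
                            std::vector<float> & dst,
                            int ncols_x, int nrows_x, int ncols_y, int nrows_dst) {
        for (int col = 0; col < ncols_y; ++col) {
            for (int row = 0; row < nrows_x; ++row) {
                float sum = 0.0f;
                for (int k = 0; k < ncols_x; ++k) {
                    sum += x[row*ncols_x + k] * y[col*ncols_x + k];
                }
                dst[col*nrows_dst + row] = sum;
            }
        }
    }

    int main() {
        const int ncols_x = 2, nrows_x = 2, ncols_y = 1;
        std::vector<float> x = {1, 2, 3, 4}, y = {10, 100}, dst(2);
        mul_mat_ref(x, y, dst, ncols_x, nrows_x, ncols_y, nrows_x);
        // dst = {1*10 + 2*100, 3*10 + 4*100} = {210, 340}
    }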
-template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1813,7 +2557,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
}
const int blocks_per_row = ncols / qk;
- const int blocks_per_warp = WARP_SIZE / qi;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
@@ -1822,11 +2566,11 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
- const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
+ const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
- const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx
+ const int iby = (i + threadIdx.x / (qi/vdr)) * qk/QK8_1; // y block index that aligns with ibx
- const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int
+ const int iqs = threadIdx.x % (qi/vdr); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
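The new vdr template parameter sets how many of a block's qi quant words one thread consumes per vec_dot call: qi/vdr threads now share an x block, so a warp covers vdr*WARP_SIZE/qi blocks per iteration. With the values visible in this patch (VDR_q5_K_q8_1 == 2, VDR_q6_K_q8_1 == 1) and assuming the usual QI5_K == QI6_K == 32 for QK_K == 256, this reproduces the former QI5_K/2 trick for q5_K while leaving q6_K unchanged:

    #include <cstdio>

    int main() {
        const int WARP_SIZE = 32;
        const struct { const char * name; int qi, vdr; } cfgs[] = {
            {"q5_K", 32, 2},   // VDR_q5_K_q8_1 == 2, QI5_K assumed 32
            {"q6_K", 32, 1},   // VDR_q6_K_q8_1 == 1, QI6_K assumed 32
        };
        for (const auto & c : cfgs) {
            printf("%s: %d threads per x block, %d blocks per warp iteration\n",
                   c.name, c.qi / c.vdr, c.vdr * WARP_SIZE / c.qi);
        }
    }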
@@ -1859,11 +2603,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
const int y_offset = qr == 1 ? 1 : qk/2;
// partial sum for each thread
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
float tmp = 0.0f;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
@@ -1883,7 +2627,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
// matrix multiplication
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
@@ -1891,7 +2635,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
#else
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
}
@@ -1902,11 +2646,11 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
}
if (tid == 0) {
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
dst[row] = tmp.x + tmp.y;
#else
dst[row] = tmp;
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
}
@@ -2203,9 +2947,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
-static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
- quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
+ const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ const dim3 num_blocks(block_num_x, ky, 1);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+ quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -2366,7 +3112,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2375,7 +3121,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
+ mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2384,7 +3130,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2393,7 +3139,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2402,7 +3148,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, vec_dot_q8_0_q8_1>
+ mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2411,7 +3157,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK_K, QI2_K, block_q2_K, vec_dot_q2_K_q8_1>
+ mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2420,7 +3166,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
+ mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2429,10 +3175,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
- // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
- // is better amortized.
- mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
+ mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2441,10 +3184,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
- // kernel call instead of 2. This results in a better perfmance because the cost of computing the k-quant scales
- // is better amortized.
- mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
+ mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2453,7 +3193,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
+ mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
@@ -2500,6 +3240,126 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}
+static void ggml_mul_mat_q4_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_1_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, VDR_q4_1_q8_1, vec_dot_q4_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, VDR_q5_0_q8_1, vec_dot_q5_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_1_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, VDR_q5_1_q8_1, vec_dot_q5_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q8_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0, VDR_q8_0_q8_1, vec_dot_q8_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q2_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q3_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q5_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q6_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
+
+ const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
+ const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+ mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
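All of the ggml_mul_mat_*_q8_1_cuda wrappers above share the same launch geometry; only the template arguments differ. A worked example of that geometry with made-up sizes (the matrix shape and GGML_CUDA_MMQ_Y value below are illustrative, not taken from the patch):

    #include <cstdio>

    int main() {
        const int WARP_SIZE = 32, GGML_CUDA_MMQ_Y = 64;
        const int nrows_x = 4096, ncols_y = 512;      // hypothetical problem size
        const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
        const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
        printf("grid = (%d, %d, 1), block = (%d, %d, 1)\n",
               block_num_x, block_num_y, WARP_SIZE, WARP_SIZE/4);
        // -> grid = (64, 16, 1), block = (32, 8, 1): each block computes a
        //    64 x 32 tile of dst with 256 threads.
    }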
static void ggml_mul_mat_p021_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
@@ -2965,6 +3825,83 @@ inline void ggml_cuda_op_rms_norm(
(void) i1;
}
+inline void ggml_cuda_op_mul_mat_q(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddq_i != nullptr);
+ GGML_ASSERT(src1_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t i01_diff = i01_high - i01_low;
+
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+ const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff;
+
+ const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
+ ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
+ size_t as;
+ void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as);
+ quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main);
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q4_1:
+ ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_0:
+ ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_1:
+ ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q2_K:
+ ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q3_K:
+ ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q4_K:
+ ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_K:
+ ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ case GGML_TYPE_Q6_K:
+ ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ ggml_cuda_pool_free(src1_q8_1, as);
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
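Before the integer kernels run, src1 is quantized on the fly: ne11 rows of ne10 floats become padded_row_size/QK8_1 q8_1 blocks each, with padded_row_size rounding ne10 up to the next multiple of MATRIX_ROW_PADDING. A worked example of the staging-buffer size; the padding constant of 512 and the 36-byte block_q8_1 (a half2 ds plus QK8_1 int8 quants) are assumptions for illustration, not quoted from the patch:

    #include <cstdio>

    int main() {
        const int QK8_1 = 32;
        const int MATRIX_ROW_PADDING = 512;          // assumed value
        const size_t sizeof_block_q8_1 = 4 + QK8_1;  // half2 ds + 32 int8 quants (assumed layout)

        const int ne10 = 4136;                       // hypothetical src1 row length
        const int ne11 = 512;                        // hypothetical number of src1 rows
        const int padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ?
            ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;

        printf("padded_row_size = %d, q8_1 staging buffer = %zu bytes\n",
               padded_row_size, (size_t) padded_row_size * ne11 * sizeof_block_q8_1 / QK8_1);
        // -> padded_row_size = 4608, buffer = 4608/32 * 512 * 36 = 2654208 bytes
    }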
inline void ggml_cuda_op_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -3006,7 +3943,7 @@ inline void ggml_cuda_op_mul_mat_vec(
ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
size_t as;
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
- quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
+ quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main);
switch (src0->type) {
case GGML_TYPE_Q4_0:
@@ -3047,7 +3984,7 @@ inline void ggml_cuda_op_mul_mat_vec(
ggml_cuda_pool_free(src1_q8_1, as);
} else {
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
size_t ash;
dfloat * src1_dfloat = nullptr; // dfloat == half
@@ -3063,7 +4000,7 @@ inline void ggml_cuda_op_mul_mat_vec(
}
#else
dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
switch (src0->type) {
case GGML_TYPE_Q4_0:
@@ -3104,11 +4041,11 @@ inline void ggml_cuda_op_mul_mat_vec(
break;
}
-#ifdef GGML_CUDA_DMMV_F16
+#ifdef GGML_CUDA_F16
if (src1_convert_f16) {
ggml_cuda_pool_free(src1_dfloat, ash);
}
-#endif // GGML_CUDA_DMMV_F16
+#endif // GGML_CUDA_F16
}
(void) src1;
@@ -3363,7 +4300,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
int64_t row_low, row_high;
if (split) {
row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
+ row_low -= row_low % GGML_CUDA_MMQ_Y;
+
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
+ row_high -= row_high % GGML_CUDA_MMQ_Y;
} else {
row_low = 0;
row_high = nrows0*i02_divisor;
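Snapping row_low and row_high down to multiples of GGML_CUDA_MMQ_Y makes each device's slice of a split tensor consist of whole mmq row tiles, so the tiled kernels never read past a device's row range. A worked example with a made-up two-GPU split:

    #include <cstdio>

    int main() {
        const int GGML_CUDA_MMQ_Y = 64;
        const int nrows0 = 4096;
        const float split[2] = {0.0f, 0.37f};        // hypothetical g_tensor_split
        for (int id = 0; id < 2; ++id) {
            int row_low  = id == 0 ? 0      : (int)(nrows0*split[id]);
            row_low     -= row_low  % GGML_CUDA_MMQ_Y;
            int row_high = id == 1 ? nrows0 : (int)(nrows0*split[id + 1]);
            row_high    -= row_high % GGML_CUDA_MMQ_Y;
            printf("device %d: rows [%d, %d)\n", id, row_low, row_high);
        }
        // -> device 0: rows [0, 1472), device 1: rows [1472, 4096)
    }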
@@ -3717,7 +4657,16 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
} else {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+#ifdef GGML_CUDA_CUBLAS
+ const bool use_mul_mat_q = false;
+#else
+ const bool use_mul_mat_q = ggml_is_quantized(src0->type);
+#endif // GGML_CUDA_CUBLAS
+ if (use_mul_mat_q) {
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
+ } else {
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
+ }
}
} else {
GGML_ASSERT(false);
@@ -3827,7 +4776,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
row_high = nrows;
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
+ row_low -= row_low % GGML_CUDA_MMQ_Y;
+
row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
+ row_high -= row_high % GGML_CUDA_MMQ_Y;
} else {
GGML_ASSERT(false);
}