path: root/ggml-cuda.cu
author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2023-06-26 19:43:07 +0300
committer  GitHub <noreply@github.com>  2023-06-26 19:43:07 +0300
commit     6769e944c727c63612dcafbef52009d21ae00fff (patch)
tree       987a35bf7f7c0e0947c85bc75cba047834fbccd5 /ggml-cuda.cu
parent     cbebf61ca7584e9709265395f0127ae7fc0f1882 (diff)
k-quants : support for super-block size of 64 (#2001)
* k_quants: WIP super-blocks with 64 weights
* k_quants: WIP super-blocks with 64 weights
  Q6_K scalar and AVX2 work
* k_quants: WIP super-blocks with 64 weights
  Q4_K scalar and AVX2 work
* k_quants: WIP super-blocks with 64 weights
  Q2_K scalar and AVX2 work. Q2_K is way too slow (it is actually slower than the scalar implementation)
* k_quants: WIP super-blocks with 64 weights
  Q3_K scalar and AVX2 work.
* k_quants: WIP super-blocks with 64 weights
  Q5_K scalar and AVX2 work, and with that all k_quants are done on AVX2 and scalar
* k_quants: WIP super-blocks with 64 weights
  Q6_K working on CUDA. Cannot make it run quite as fast as with super-blocks of 256 weights: 8% slower on the 4080, 20% slower on the 1660 (but there we fit one less layer on the GPU because of the larger model size), so some fraction of these 20% is due to that.
* k_quants: WIP super-blocks with 64 weights
  Q4_K working on CUDA. ~10% slower on GTX-1660, 16% slower on 4080.
* k_quants: WIP super-blocks with 64 weights
  Q2_K working on CUDA. ~3% slower on GTX-1660, 10% slower on 4080.
* k_quants: WIP super-blocks with 64 weights
  Q3_K working on CUDA.
* k_quants: WIP super-blocks with 64 weights
  Q5_K working on CUDA, and with this CUDA is done.
* k_quants: WIP super-blocks with 64 weights
  Q6_K working on ARM_NEON
* k_quants: WIP super-blocks with 64 weights
  Q4_K working on ARM_NEON, but quite a bit slower than 256 weights
* k_quants: WIP super-blocks with 64 weights
  Q2_K working on ARM_NEON, but quite a bit slower than 256 weights
* k_quants: WIP super-blocks with 64 weights
  Q3_K working on ARM_NEON, but quite a bit slower than 256 weights.
* k_quants: WIP super-blocks with 64 weights
  Q5_K working on ARM_NEON, but quite a bit slower than 256 weights. With that, we have full support for ARM_NEON, although performance is not quite there.
* k_quants: WIP super-blocks with 64 weights
  Slightly more efficient Q3_K and Q5_K
* k_quants: WIP super-blocks with 64 weights
  Another small improvement for Q3_K and Q5_K on ARM_NEON
* k_quants: WIP super-blocks with 64 weights
  Yet another speedup for Q5_K on ARM_NEON. We are now within 10% of the QK_K = 256 version.
* k_quants: WIP super-blocks with 64 weights
  * We are able to pass preprocessor macros to the Metal compiler
  * Q6_K works and is actually slightly more efficient than the QK_K = 256 version (25.2 ms vs 25.8 ms)
* k_quants: WIP super-blocks with 64 weights
  Q4_K works on Metal and is actually slightly faster than QK_K = 256 (21.95 ms vs 24.0 ms).
* k_quants: WIP super-blocks with 64 weights
  Q2_K works on Metal and is very slightly faster than QK_K = 256 (23.8 ms vs 24.2 ms).
* k_quants: WIP super-blocks with 64 weights
  Q3_K works on Metal and is slightly faster than QK_K = 256 (26.6 ms vs 28.3 ms).
* k_quants: WIP super-blocks with 64 weights
  Q5_K works on Metal and is slightly faster than QK_K = 256 (23.7 ms vs 26.3 ms).
* k_quants: call them _K, not _k, also on Metal
* k_quants: correctly define QK_K in llama.cpp
* Fixed bug in q4_K quantization added with the 64-block addition
* Simplify via lambda
* k_quants: switch Q3_K to 4-bit scales when QK_K = 64
  Otherwise there isn't much benefit from this quantization type. There is some very slight loss in accuracy, but we reduce size by ~7%. E.g., for OpenLLaMA-3B, Q3_K_S perplexity is 8.6131 with 8-bit scales and 8.6352 with 4-bit, while file size decreases from 1.53G to 1.44G.
* k_quants: switch Q4_K to 4-bit scales when QK_K = 64
  Here the loss in accuracy is greater than for Q3_K, but the Q4_K points still move further to the left on the perplexity vs size curve.
* k_quants: forgot to add the Metal changes in the last commit
* k_quants: change Q5_K to be type 0 when QK_K = 64
  Still needs AVX2 implementation
* k_quants: AVX2 implementation for the new 64-weight Q5_K
* k_quants: 10% faster ARM_NEON Q5_K dot product
* k_quants: fixed issue caused by merging with master
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
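A quick back-of-the-envelope check of the ~7% figure (editor's arithmetic, not part of the commit): with QK_K = 64, a block_q3_K holds 8 bytes of hmask, 16 bytes of qs, and a 2-byte fp16 super-block scale, plus either 2 bytes (4-bit) or 4 bytes (8-bit) of block scales.

// Editor's sketch, not from the commit: size of block_q3_K at QK_K == 64
// with 4-bit vs 8-bit block scales.
#include <stdio.h>
int main(void) {
    const int hmask = 64/8, qs = 64/4, d = 2;           // 8 + 16 + 2 bytes
    const int sz4 = hmask + qs + 2 + d;                 // 28 B -> 3.50 bits/weight
    const int sz8 = hmask + qs + 4 + d;                 // 30 B -> 3.75 bits/weight
    printf("saving: %.1f%%\n", 100.0*(sz8 - sz4)/sz8);  // ~6.7%, i.e. the ~7% above
    return 0;
}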
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  368
1 file changed, 317 insertions, 51 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5e2fbc7..c34e96a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
//================================= k-quants
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
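The 64-weight layout is selected entirely at compile time: defining GGML_QKK_64 (assumed to come from the build system, e.g. -DGGML_QKK_64 on the nvcc command line) switches every struct and kernel below. A sanity check of the two configurations, an editor's sketch rather than code from the commit:

// Editor's sketch: the only two configurations the #ifdef above can produce.
#ifdef GGML_QKK_64
static_assert(QK_K == 64  && K_SCALE_SIZE == 4,  "unexpected QKK_64 layout");
#else
static_assert(QK_K == 256 && K_SCALE_SIZE == 12, "unexpected default layout");
#endif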
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
typedef struct {
- uint8_t hmask[QK_K/8];
- uint8_t qs[QK_K/4]; // nibbles / quants
- uint8_t scales[3*QK_K/64];
- half d;
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+#ifdef GGML_QKK_64
+ uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
} block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+#ifdef GGML_QKK_64
+typedef struct {
+ half d[2]; // super-block scales/mins
+ uint8_t scales[2]; // 4-bit block scales/mins
+ uint8_t qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
+#ifdef GGML_QKK_64
typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+ half d; // super-block scale
+ int8_t scales[QK_K/16]; // block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
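Note the shape change for q5_K: the 256-weight layout is a "type-1" quantization (per-block scale and min), while the 64-weight layout is "type-0" (signed per-block scales, no mins), as the commit message says. Schematically, how one weight decodes in each case (editor's sketch, consistent with the dequantize kernels below; q is the 4-bit low quant, hb its high bit from qh):

// QK_K == 256, type 1: 6-bit block scale sc and min mn, scaled by fp16 d/dmin
//   w = (float)d * sc * (q + 16*hb) - (float)dmin * mn;
// QK_K == 64, type 0: signed 8-bit block scale s, no min; the high bit just
// cancels a -16 offset, i.e. the kernels' ((h & 1) ? 0 : 16) subtraction
//   w = (float)d * s * ((q + 16*hb) - 16);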
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
const int i = blockIdx.x;
+ const block_q2_K * x = (const block_q2_K *) vx;
+
const int tid = threadIdx.x;
+#if QK_K == 256
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
- const block_q2_K * x = (const block_q2_K *) vx;
-
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
@@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ float * y = yy + i*QK_K + 16*is + il;
+ float dall = x[i].d;
+ float dmin = x[i].dmin;
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
}
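To make the new packing explicit: in the QK_K == 64 branch, weight w takes its 2 bits from byte qs[w % 16] at shift 2*(w / 16), and sub-block w / 16 supplies one 4-bit scale/min pair. A scalar reference loop (editor's sketch, not part of the commit) producing the same 64 outputs as the 32-thread branch above:

// Editor's sketch: scalar reference for one 64-weight q2_K super-block.
static void dequantize_q2_K_64_ref(const block_q2_K * x, float * y) {
    const float dall = __half2float(x->d);
    const float dmin = __half2float(x->dmin);
    for (int w = 0; w < 64; ++w) {
        const int j = w / 16;                           // sub-block 0..3
        const uint8_t q = (x->qs[w % 16] >> (2*j)) & 3; // 2-bit quant
        y[w] = dall * (x->scales[j] & 0xF) * q          // 4-bit scale
             - dmin * (x->scales[j] >> 4);              // 4-bit min
    }
}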
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
- int r = threadIdx.x/4;
- int i = blockIdx.x;
- int tid = r/2;
- int is0 = r%2;
- int l0 = 16*is0 + 4*(threadIdx.x%4);
- int n = tid / 4;
- int j = tid - 4*n;
-
+ const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
+#if QK_K == 256
+ const int r = threadIdx.x/4;
+ const int tid = r/2;
+ const int is0 = r%2;
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
+ const int n = tid / 4;
+ const int j = tid - 4*n;
+
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
@@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+ const int tid = threadIdx.x;
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const int im = il/8; // 0...1
+ const int in = il%8; // 0...7
+
+ float * y = yy + i*QK_K + 16*is + il;
+
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
+ const float d = (float)x[i].d;
+
+ if (is == 0) {
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ } else {
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ }
+#endif
}
+#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
+#endif
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
- //// assume 64 threads - this is very slightly better than the one below
- //const int tid = threadIdx.x;
- //const int il = tid/16;
- //const int ir = tid%16;
- //const int is = 2*il;
- //const int n = 2;
-
+#if QK_K == 256
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
@@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
+#else
+ const int tid = threadIdx.x;
+ const uint8_t * q = x[i].qs;
+ float * y = yy + i*QK_K;
+ const float d = (float)x[i].d[0];
+ const float m = (float)x[i].d[1];
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
+#endif
}
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
const int i = blockIdx.x;
+#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int il = tid/16; // il is in 0...3
@@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+ const int tid = threadIdx.x;
+ const uint8_t q = x[i].qs[tid];
+ const int im = tid/8; // 0...3
+ const int in = tid%8; // 0...7
+ const int is = tid/16; // 0 or 1
+ const uint8_t h = x[i].qh[in] >> im;
+ const float d = x[i].d;
+ float * y = yy + i*QK_K + tid;
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
}
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
+#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
@@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+ // assume 32 threads
+ const int tid = threadIdx.x;
+ const int ip = tid/16; // 0 or 1
+ const int il = tid - 16*ip; // 0...15
+
+ float * y = yy + i*QK_K + 16*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t ql = x[i].ql[16*ip + il];
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
+ const int8_t * sc = x[i].scales;
+
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
}
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const block_q2_K * x = (const block_q2_K *)vx + ib0;
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
- float tmp = 0; // partial sum for thread in warp
-
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
tmp += dall * sum1 - dmin * sum2;
}
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION;
+
+ uint32_t uaux[2];
+ const uint8_t * d = (const uint8_t *)uaux;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint32_t * s = (const uint32_t *)x[i].scales;
+
+ uaux[0] = s[0] & 0x0f0f0f0f;
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+ const half2 * dh = (const half2 *)&x[i].d;
+
+ const float2 dall = __half22float2(dh[0]);
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t ql = q[l];
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+ }
+ tmp += dall.x * sum1 - dall.y * sum2;
+ }
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
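One micro-optimization in the QK_K == 64 branch above is worth a note: d and dmin are adjacent fp16 fields in block_q2_K, so one vectorized half2 read plus a single __half22float2 conversion yields both super-block scales (an editor's annotation of the diff, no new logic):

// {d, dmin} fetched with one load, converted with one intrinsic
const half2 * dh  = (const half2 *)&x[i].d;  // reinterpret the adjacent halves
const float2 dall = __half22float2(dh[0]);   // dall.x = d, dall.y = dmin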
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
- const uint16_t kmask1 = 0x0303;
- const uint16_t kmask2 = 0x0f0f;
-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
@@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const block_q3_K * x = (const block_q3_K *)vx + ib0;
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const uint16_t s_shift = 4*im;
- float tmp = 0; // partial sum for thread in warp
-
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
@@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += d * sum;
}
+#else
+
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
+ const int in = offset/8; // 0 or 1
+ const int im = offset%8; // 0...7
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint8_t * s = x[i].scales;
+
+ const float dall = (float)x[i].d;
+
+ float sum = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t hl = x[i].hmask[im+l] >> in;
+ const uint8_t ql = q[l];
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+ }
+ tmp += sum;
+ }
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
- const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
}
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ float tmp = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const float * y = yy + i*QK_K + step;
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+ const float d = (float)x[i].d[0];
+ const float m = (float)x[i].d[1];
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+ }
+ tmp += sum;
+ }
+
+#endif
// sum up partial sums and write back result
__syncthreads();
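In the q4_K QK_K == 64 branch above, the two scale bytes each pack a 4-bit scale in the low nibble and a 4-bit min in the high nibble; reading both bytes as one uint16_t lets two masked operations split all four values (editor's annotation, no new logic):

// a[0] holds scales[0] and scales[1] as one uint16_t
// aux16[0] = a[0] & 0x0f0f;         -> s[0], s[1]: the two 4-bit scales
// aux16[1] = (a[0] >> 4) & 0x0f0f;  -> s[2], s[3]: the two 4-bit mins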
@@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
- //const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
@@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
- const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
- float tmp = 0; // partial sum for thread in warp
-
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+ }
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+ const int step = tid * K_QUANTS_PER_ITERATION;
+ const int im = step/8;
+ const int in = step%8;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const int8_t * s = x[i].scales;
+ const float * y = yy + i*QK_K + step;
+ const float d = x[i].d;
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ const uint8_t h = x[i].qh[in+j] >> im;
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+ }
+ tmp += sum;
}
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
@@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
const block_q6_K * x = (const block_q6_K *)vx + ib0;
+#if QK_K == 256
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
@@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
}
+#else
+
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + step;
+ const uint8_t * ql = x[i].ql + step;
+ const uint8_t * qh = x[i].qh + step;
+ const int8_t * s = x[i].scales;
+
+ const float d = x[i+0].d;
+
+ float sum = 0;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+ }
+ tmp += sum;
+
+ }
+
+#endif
+
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
@@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
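These launcher changes all encode the same rule: each thread of the kernels launched here writes four floats in the 256-weight layout but two in the 64-weight layout, so a super-block of QK_K weights needs 64 or 32 threads respectively. An editor's sketch (not part of the commit) making the choice explicit:

// Editor's sketch: one place for the per-layout thread count.
#if QK_K == 256
#define DEQUANT_BLOCK_THREADS 64   // 256 weights, 4 written per thread
#else
#define DEQUANT_BLOCK_THREADS 32   //  64 weights, 2 written per thread
#endif
// e.g. dequantize_block_q2_K<<<nb, DEQUANT_BLOCK_THREADS, 0, stream>>>(vx, y);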
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {