aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt14
-rw-r--r--Makefile9
-rw-r--r--ggml-cuda.cu368
-rw-r--r--ggml-metal.m66
-rw-r--r--ggml-metal.metal412
-rw-r--r--k_quants.c1140
-rw-r--r--k_quants.h47
-rw-r--r--llama.cpp17
8 files changed, 1876 insertions, 197 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cc7560a..ffda74a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
+option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -225,6 +226,14 @@ if (LLAMA_BLAS)
endif()
endif()
+if (LLAMA_K_QUANTS)
+ set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
+ add_compile_definitions(GGML_USE_K_QUANTS)
+ if (LLAMA_QKK_64)
+ add_compile_definitions(GGML_QKK_64)
+ endif()
+endif()
+
if (LLAMA_CUBLAS)
cmake_minimum_required(VERSION 3.17)
@@ -289,11 +298,6 @@ if (LLAMA_METAL)
)
endif()
-if (LLAMA_K_QUANTS)
- set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
- add_compile_definitions(GGML_USE_K_QUANTS)
-endif()
-
if (LLAMA_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
diff --git a/Makefile b/Makefile
index 5dd676f..bda1179 100644
--- a/Makefile
+++ b/Makefile
@@ -43,8 +43,11 @@ endif
# keep standard at C11 and C++11
# -Ofast tends to produce faster code, but may not be available for some compilers.
-#OPT = -Ofast
+ifdef LLAMA_FAST
+OPT = -Ofast
+else
OPT = -O3
+endif
CFLAGS = -I. $(OPT) -std=c11 -fPIC
CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
LDFLAGS =
@@ -131,6 +134,10 @@ ifndef LLAMA_NO_K_QUANTS
CFLAGS += -DGGML_USE_K_QUANTS
CXXFLAGS += -DGGML_USE_K_QUANTS
OBJS += k_quants.o
+ifdef LLAMA_QKK_64
+ CFLAGS += -DGGML_QKK_64
+ CXXFLAGS += -DGGML_QKK_64
+endif
endif
ifndef LLAMA_NO_ACCELERATE
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5e2fbc7..c34e96a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
//================================= k-quants
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
typedef struct {
- uint8_t hmask[QK_K/8];
- uint8_t qs[QK_K/4]; // nibbles / quants
- uint8_t scales[3*QK_K/64];
- half d;
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+#ifdef GGML_QKK_64
+ uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
} block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+#ifdef GGML_QKK_64
+typedef struct {
+ half d[2]; // super-block scales/mins
+ uint8_t scales[2]; // 4-bit block scales/mins
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
+#ifdef GGML_QKK_64
typedef struct {
- half d; // super-block scale for quantized scales
- half dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+ half d; // super-block scale
+ int8_t scales[QK_K/16]; // block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
const int i = blockIdx.x;
+ const block_q2_K * x = (const block_q2_K *) vx;
+
const int tid = threadIdx.x;
+#if QK_K == 256
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
- const block_q2_K * x = (const block_q2_K *) vx;
-
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
@@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ float * y = yy + i*QK_K + 16*is + il;
+ float dall = x[i].d;
+ float dmin = x[i].dmin;
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
}
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
- int r = threadIdx.x/4;
- int i = blockIdx.x;
- int tid = r/2;
- int is0 = r%2;
- int l0 = 16*is0 + 4*(threadIdx.x%4);
- int n = tid / 4;
- int j = tid - 4*n;
-
+ const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
+#if QK_K == 256
+ const int r = threadIdx.x/4;
+ const int tid = r/2;
+ const int is0 = r%2;
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
+ const int n = tid / 4;
+ const int j = tid - 4*n;
+
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
@@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+ const int tid = threadIdx.x;
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const int im = il/8; // 0...1
+ const int in = il%8; // 0...7
+
+ float * y = yy + i*QK_K + 16*is + il;
+
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
+ const float d = (float)x[i].d;
+
+ if (is == 0) {
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ } else {
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ }
+#endif
}
+#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
+#endif
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
- //// assume 64 threads - this is very slightly better than the one below
- //const int tid = threadIdx.x;
- //const int il = tid/16;
- //const int ir = tid%16;
- //const int is = 2*il;
- //const int n = 2;
-
+#if QK_K == 256
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
@@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
+#else
+ const int tid = threadIdx.x;
+ const uint8_t * q = x[i].qs;
+ float * y = yy + i*QK_K;
+ const float d = (float)x[i].d[0];
+ const float m = (float)x[i].d[1];
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
+#endif
}
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
const int i = blockIdx.x;
+#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int il = tid/16; // il is in 0...3
@@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+ const int tid = threadIdx.x;
+ const uint8_t q = x[i].qs[tid];
+ const int im = tid/8; // 0...3
+ const int in = tid%8; // 0...7
+ const int is = tid/16; // 0 or 1
+ const uint8_t h = x[i].qh[in] >> im;
+ const float d = x[i].d;
+ float * y = yy + i*QK_K + tid;
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
}
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
+#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
@@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+ // assume 32 threads
+ const int tid = threadIdx.x;
+ const int ip = tid/16; // 0 or 1
+ const int il = tid - 16*ip; // 0...15
+
+ float * y = yy + i*QK_K + 16*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t ql = x[i].ql[16*ip + il];
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
+ const int8_t * sc = x[i].scales;
+
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
}
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const block_q2_K * x = (const block_q2_K *)vx + ib0;
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
- float tmp = 0; // partial sum for thread in warp
-
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
tmp += dall * sum1 - dmin * sum2;
}
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION;
+
+ uint32_t uaux[2];
+ const uint8_t * d = (const uint8_t *)uaux;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint32_t * s = (const uint32_t *)x[i].scales;
+
+ uaux[0] = s[0] & 0x0f0f0f0f;
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+ const half2 * dh = (const half2 *)&x[i].d;
+
+ const float2 dall = __half22float2(dh[0]);
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t ql = q[l];
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+ }
+ tmp += dall.x * sum1 - dall.y * sum2;
+ }
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
- const uint16_t kmask1 = 0x0303;
- const uint16_t kmask2 = 0x0f0f;
-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
@@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const block_q3_K * x = (const block_q3_K *)vx + ib0;
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
const uint16_t s_shift = 4*im;
- float tmp = 0; // partial sum for thread in warp
-
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
@@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += d * sum;
}
+#else
+
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
+ const int in = offset/8; // 0 or 1
+ const int im = offset%8; // 0...7
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint8_t * s = x[i].scales;
+
+ const float dall = (float)x[i].d;
+
+ float sum = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t hl = x[i].hmask[im+l] >> in;
+ const uint8_t ql = q[l];
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+ }
+ tmp += sum;
+ }
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
@@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
- const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
}
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ float tmp = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const float * y = yy + i*QK_K + step;
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+ const float d = (float)x[i].d[0];
+ const float m = (float)x[i].d[1];
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+ }
+ tmp += sum;
+ }
+
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
- //const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
@@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
- const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
- float tmp = 0; // partial sum for thread in warp
-
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+ }
+#else
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+ const int step = tid * K_QUANTS_PER_ITERATION;
+ const int im = step/8;
+ const int in = step%8;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const int8_t * s = x[i].scales;
+ const float * y = yy + i*QK_K + step;
+ const float d = x[i].d;
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ const uint8_t h = x[i].qh[in+j] >> im;
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+ }
+ tmp += sum;
}
+#endif
// sum up partial sums and write back result
__syncthreads();
@@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
- if (tid == 0) {
+ if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
@@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
const block_q6_K * x = (const block_q6_K *)vx + ib0;
+#if QK_K == 256
+
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
@@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
}
+#else
+
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + step;
+ const uint8_t * ql = x[i].ql + step;
+ const uint8_t * qh = x[i].qh + step;
+ const int8_t * s = x[i].scales;
+
+ const float d = x[i+0].d;
+
+ float sum = 0;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+ }
+ tmp += sum;
+
+ }
+
+#endif
+
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
@@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
+#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+ dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
diff --git a/ggml-metal.m b/ggml-metal.m
index a7e104d..7551231 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -51,21 +51,21 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
- GGML_METAL_DECL_KERNEL(get_rows_q2_k);
- GGML_METAL_DECL_KERNEL(get_rows_q3_k);
- GGML_METAL_DECL_KERNEL(get_rows_q4_k);
- GGML_METAL_DECL_KERNEL(get_rows_q5_k);
- GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+ GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
- GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
exit(1);
}
+#ifdef GGML_QKK_64
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
- GGML_METAL_ADD_KERNEL(get_rows_q2_k);
- GGML_METAL_ADD_KERNEL(get_rows_q3_k);
- GGML_METAL_ADD_KERNEL(get_rows_q4_k);
- GGML_METAL_ADD_KERNEL(get_rows_q5_k);
- GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+ GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
- GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -662,7 +668,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
@@ -671,7 +677,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
@@ -680,7 +686,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
@@ -689,7 +695,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
@@ -698,7 +704,7 @@ void ggml_metal_graph_compute(
nth0 = 4;
nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
{
@@ -750,11 +756,11 @@ void ggml_metal_graph_compute(
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
default: GGML_ASSERT(false && "not implemented");
}
diff --git a/ggml-metal.metal b/ggml-metal.metal
index d1e4922..e62fe68 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
}
}
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
}
}
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
//============================================ k-quants ======================================================
+#ifndef QK_K
#define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
-} block_q2_k;
+} block_q2_K;
// 84 bytes / block
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
- half d; // super-block scale
-} block_q3_k;
-// 110 bytes / block
-
+#if QK_K == 64
+ uint8_t scales[2];
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+ half d[2]; // super-block scales/mins
+ uint8_t scales[2];
+ uint8_t qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
-} block_q4_k;
-// 144 bytes / block
+} block_q4_K;
+#endif
+#if QK_K == 64
+typedef struct {
+ half d; // super-block scales/mins
+ int8_t scales[QK_K/16]; // 8-bit block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+#else
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
-} block_q5_k;
+} block_q5_K;
// 176 bytes / block
+#endif
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
half d; // super-block scale
-} block_q6_k;
+} block_q6_K;
// 210 bytes / block
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
//========================================== dequantization =============================
-static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
device const uint8_t * q = x[i].qs;
+#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
}
q += 32;
}
+#else
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+ for (int l = 0; l < 16; ++l) {
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
+ }
+ y += QK_K;
+#endif
}
}
-static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
+#if QK_K == 256
+
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
}
q += 32;
}
+ }
+#else
+ for (int i = 0; i < nb; i++) {
+ const float d_all = (float)(x[i].d);
+
+ device const uint8_t * q = x[i].qs;
+ device const uint8_t * hm = x[i].hmask;
+
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+
+ for (int l = 0; l < 8; ++l) {
+ uint8_t h = hm[l];
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+ }
+ y += QK_K;
}
+#endif
}
-static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
-
for (int i = 0; i < nb; i++) {
+ device const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
const float d = x[i].d;
const float min = x[i].dmin;
- device const uint8_t * q = x[i].qs;
device const uint8_t * scales = x[i].scales;
int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
+#else
+ device const uint8_t * s = x[i].scales;
+ device const half2 * dh = (device const half2 *)x[i].d;
+ const float2 d = (float2)dh[0];
+ const float d1 = d[0] * (s[0] & 0xF);
+ const float d2 = d[0] * (s[1] & 0xF);
+ const float m1 = d[1] * (s[0] >> 4);
+ const float m2 = d[1] * (s[1] >> 4);
+ for (int l = 0; l < 32; ++l) {
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+ y[l+32] = d2 * (q[l] >> 4) - m2;
+ }
+ y += QK_K;
+#endif
}
}
-static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
+#if QK_K == 256
for (int i = 0; i < nb; i++) {
const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
u1 <<= 2; u2 <<= 2;
}
}
+#else
+ for (int i = 0; i < nb; i++) {
+
+ const float d = (float)x[i].d;
+
+ device const uint8_t * ql = x[i].qs;
+ device const uint8_t * qh = x[i].qh;
+ device const int8_t * sc = x[i].scales;
+
+ for (int l = 0; l < 8; ++l) {
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
+ }
+ y += QK_K;
+ }
+#endif
}
-static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
const float d = x[i].d;
+#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
qh += 32;
sc += 8;
}
+#else
+ for (int l = 0; l < 16; ++l) {
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ y[l+ 0] = d * sc[0] * q1;
+ y[l+16] = d * sc[1] * q2;
+ y[l+32] = d * sc[2] * q3;
+ y[l+48] = d * sc[3] * q4;
+ }
+ y += 64;
+#endif
}
}
-kernel void kernel_get_rows_q2_k(
+kernel void kernel_get_rows_q2_K(
device const void * src0,
device const int * src1,
device float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
- dequantize_row_q2_k(
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
+ dequantize_row_q2_K(
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
-kernel void kernel_get_rows_q3_k(
+kernel void kernel_get_rows_q3_K(
device const void * src0,
device const int * src1,
device float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
- dequantize_row_q3_k(
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
+ dequantize_row_q3_K(
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
-kernel void kernel_get_rows_q4_k(
+kernel void kernel_get_rows_q4_K(
device const void * src0,
device const int * src1,
device float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
- dequantize_row_q4_k(
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
+ dequantize_row_q4_K(
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
-kernel void kernel_get_rows_q5_k(
+kernel void kernel_get_rows_q5_K(
device const void * src0,
device const int * src1,
device float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
- dequantize_row_q5_k(
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
+ dequantize_row_q5_K(
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
-kernel void kernel_get_rows_q6_k(
+kernel void kernel_get_rows_q6_K(
device const void * src0,
device const int * src1,
device float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
const int i = tpig;
const int r = ((device int32_t *) src1)[i];
- dequantize_row_q6_k(
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
+ dequantize_row_q6_K(
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
(device float *) ((device char *) dst + i*nb1), ne00);
}
//====================================== dot products =========================
-kernel void kernel_mul_mat_q2_k_f32(
+kernel void kernel_mul_mat_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf = 0;
+
+#if QK_K == 256
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid%4; // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
const int y_offset = 64*il + n*ir;
const int q_offset = 32*ip + n*ir;
- sum[ith] = 0.0f;
-
- float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
device const float * y = yy + i*QK_K + y_offset;
- //float4 s = {0.f, 0.f, 0.f, 0.f};
float2 s = {0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
}
- sum[ith] = sumf;
+#else
+ const int il = 4 * tpitg.x;
- //int mask1 = (ith%4 == 0);
- //int mask2 = (ith%16 == 0);
+ uint32_t aux[2];
+ thread const uint8_t * d = (thread const uint8_t *)aux;
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //if (ith == 0) {
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- // dst[r1*ne0 + r0] = sum[0];
- //}
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+ device const uint8_t * q = x[i].qs + il;
+ device const float * y = yy + i*QK_K + il;
+
+ const float dall = (float)x[i].d;
+ const float dmin = (float)x[i].dmin;
+
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
+ aux[0] = a[0] & 0x0f0f0f0f;
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+
+ for (int l = 0; l < 4; ++l) {
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+ }
+ }
+#endif
+
+ sum[ith] = sumf;
//
// Accumulate the sum from all threads in the threadgroup
- // This version is slightly faster than the commented out one below,
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
}
}
-kernel void kernel_mul_mat_q3_k_f32(
+kernel void kernel_mul_mat_q3_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
- const uint16_t kmask1 = 0x0303;
- const uint16_t kmask2 = 0x0f0f;
-
- const uint8_t m3 = 3;
- const int8_t m4 = 4;
-
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
+#if QK_K == 256
+
+ const uint8_t m3 = 3;
+ const int8_t m4 = 4;
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
const int tid = tpitg.y; // expecting 16
const int ip = tid/8; // 0 or 1
const int il = tid/2 - 4*ip; // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
//sum[ith] = sumf;
sum[ith] = sumf1 - 32.f*sumf2;
+#else
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
+
+ float sumf = 0;
+
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint8_t * q = x[i].qs + il;
+ device const uint8_t * h = x[i].hmask + in;
+ device const float * y = yy + i * QK_K + il;
+
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hm = h[l] >> im;
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+ }
+
+ }
+
+ sum[ith] = sumf;
+
+#endif
//
// Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
}
-kernel void kernel_mul_mat_q4_k_f32(
+kernel void kernel_mul_mat_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
-
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
+
+ float sumf = 0;
+
+#if QK_K == 256
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
- sum[ith] = 0.0f;
-
uchar2 sc1, sc2, sc3, sc4;
- float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
}
+#else
+ uint16_t aux16[2];
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
+ const int il = 4*tpitg.x;
+
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+ device const uint8_t * q = x[i].qs + il;
+ device const float * y = yy + i * QK_K + il;
+
+ const float d = (float)x[i].d[0];
+ const float m = (float)x[i].d[1];
+
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ for (int l = 0; l < 4; ++l) {
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
+ }
+ }
+#endif
sum[ith] = sumf;
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
//}
}
-kernel void kernel_mul_mat_q5_k_f32(
+kernel void kernel_mul_mat_q5_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
- const uint16_t kmask1 = 0x3f3f;
- const uint16_t kmask2 = 0x0f0f;
- const uint16_t kmask3 = 0xc0c0;
-
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf = 0;
+
+#if QK_K == 256
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
uchar2 sc1, sc2, sc3, sc4;
- float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
}
+#else
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
+ const int im = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
+
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+ const float d = (float)x[i].d;
+ device const uint8_t * q = x[i].qs + il;
+ device const uint8_t * h = x[i].qh + in;
+ device const int8_t * s = x[i].scales;
+ device const float * y = yy + i*QK_K + il;
+
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hl = h[l] >> im;
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
+ }
+ }
+#endif
sum[ith] = sumf;
//
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
}
-kernel void kernel_mul_mat_q6_k_f32(
+kernel void kernel_mul_mat_q6_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
+ float sumf = 0;
+
+#if QK_K == 256
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
const int iqs = 16 * tpitg.y;
const int ip = iqs / 128; // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
const int q_offset_l = 64*ip + l0;
const int q_offset_h = 32*ip + l0;
- float sumf = 0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
}
+#else
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
+
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
+ device const float * y = yy + i * QK_K + il;
+ device const uint8_t * ql = x[i].ql + il;
+ device const uint8_t * qh = x[i].qh + il;
+ device const int8_t * s = x[i].scales;
+
+ const float d = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+ }
+
+#endif
sum[ith] = sumf;
diff --git a/k_quants.c b/k_quants.c
index a48c821..46dd884 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
return scale;
}
+#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
+#endif
//========================- 2-bit (de)-quantization
@@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
}
}
+#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
+#else
+ for (int l = 0; l < 16; ++l) {
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+ }
+#endif
x += QK_K;
@@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
const uint8_t * q = x[i].qs;
+#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
}
q += 32;
}
-
+#else
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+ for (int l = 0; l < 16; ++l) {
+ y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
+ y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
+ y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
+ y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
+ }
+ y += QK_K;
+#endif
}
}
@@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
}
}
+#if QK_K == 256
memset(y[i].scales, 0, 12);
if (max_scale) {
float iscale = -32.f/max_scale;
@@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
L[16*j + ii] = l + 4;
}
}
+#else
+ if (max_scale) {
+ float iscale = -8.f/max_scale;
+ for (int j = 0; j < QK_K/16; j+=2) {
+ int l1 = nearest_int(iscale*scales[j]);
+ l1 = 8 + MAX(-8, MIN(7, l1));
+ int l2 = nearest_int(iscale*scales[j+1]);
+ l2 = 8 + MAX(-8, MIN(7, l2));
+ y[i].scales[j/2] = l1 | (l2 << 4);
+ }
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
+ } else {
+ for (int j = 0; j < QK_K/16; j+=2) {
+ y[i].scales[j/2] = 0;
+ }
+ y[i].d = ggml_fp32_to_fp16(0.f);
+ }
+ for (int j = 0; j < QK_K/16; ++j) {
+ int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+ float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
+ if (!d) {
+ continue;
+ }
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-4, MIN(3, l));
+ L[16*j + ii] = l + 4;
+ }
+ }
+#endif
memset(y[i].hmask, 0, QK_K/8);
- // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
int m = 0;
uint8_t hm = 1;
for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
m = 0; hm <<= 1;
}
}
+#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
+#else
+ for (int l = 0; l < 16; ++l) {
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+ }
+#endif
x += QK_K;
}
}
+#if QK_K == 256
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
- assert(QK_K == 256);
const int nb = k / QK_K;
const uint32_t kmask1 = 0x03030303;
@@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
}
}
+#else
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
+ assert(k % QK_K == 0);
+ assert(QK_K == 64);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d_all = ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q = x[i].qs;
+ const uint8_t * restrict hm = x[i].hmask;
+
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
+
+ for (int l=0; l<8; ++l) {
+ uint8_t h = hm[l];
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+ }
+ y += QK_K;
+ }
+}
+#endif
void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
quantize_row_q3_K_reference(x, vy, k);
@@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
}
}
+#if QK_K == 256
float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
L[32*j + ii] = l;
}
}
+#else
+ const float s_factor = 15.f;
+ float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+ float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
+ int d1 = nearest_int(inv_scale*scales[0]);
+ int m1 = nearest_int(inv_min*mins[0]);
+ int d2 = nearest_int(inv_scale*scales[1]);
+ int m2 = nearest_int(inv_min*mins[1]);
+ y[i].scales[0] = d1 | (m1 << 4);
+ y[i].scales[1] = d2 | (m2 << 4);
+ y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+ y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+ float sumlx = 0;
+ int suml2 = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ const uint8_t sd = y[i].scales[j] & 0xF;
+ const uint8_t sm = y[i].scales[j] >> 4;
+ const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
+ if (!d) continue;
+ const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
+ for (int ii = 0; ii < 32; ++ii) {
+ int l = nearest_int((x[32*j + ii] + m)/d);
+ l = MAX(0, MIN(15, l));
+ L[32*j + ii] = l;
+ sumlx += (x[32*j + ii] + m)*l*sd;
+ suml2 += l*l*sd*sd;
+ }
+ }
+ if (suml2) {
+ y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+ }
+#endif
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
- for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+ q += 32;
}
x += QK_K;
@@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
for (int i = 0; i < nb; i++) {
- const float d = ggml_fp16_to_fp32(x[i].d);
- const float min = ggml_fp16_to_fp32(x[i].dmin);
-
const uint8_t * q = x[i].qs;
+#if QK_K == 256
+
+ const float d = ggml_fp16_to_fp32(x[i].d);
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
+
int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
+#else
+ const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+ const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+ const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+ const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
+ for (int l = 0; l < 32; ++l) {
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+ y[l+32] = d2 * (q[l] >> 4) - m2;
+ }
+ y += QK_K;
+#endif
}
}
@@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
assert(k % QK_K == 0);
const int nb = k / QK_K;
+#if QK_K == 256
uint8_t L[QK_K];
float mins[QK_K/32];
float scales[QK_K/32];
+#else
+ int8_t L[QK_K];
+ float scales[QK_K/16];
+#endif
for (int i = 0; i < nb; i++) {
+#if QK_K == 256
+
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
m1 <<= 2; m2 <<= 2;
ql += 32;
}
+#else
+ float max_scale = 0, amax = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+ float abs_scale = fabsf(scales[j]);
+ if (abs_scale > amax) {
+ amax = abs_scale;
+ max_scale = scales[j];
+ }
+ }
+
+ float iscale = -128.f/max_scale;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(iscale*scales[j]);
+ y[i].scales[j] = MAX(-128, MIN(127, l));
+ }
+ y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+ if (!d) continue;
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-16, MIN(15, l));
+ L[16*j + ii] = l + 16;
+ }
+ }
+
+ uint8_t * restrict qh = y[i].qh;
+ uint8_t * restrict ql = y[i].qs;
+ memset(qh, 0, QK_K/8);
+
+ for (int j = 0; j < 32; ++j) {
+ int jm = j%8;
+ int is = j/8;
+ int l1 = L[j];
+ if (l1 > 15) {
+ l1 -= 16; qh[jm] |= (1 << is);
+ }
+ int l2 = L[j + 32];
+ if (l2 > 15) {
+ l2 -= 16; qh[jm] |= (1 << (4 + is));
+ }
+ ql[j] = l1 | (l2 << 4);
+ }
+#endif
x += QK_K;
@@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
for (int i = 0; i < nb; i++) {
- const float d = ggml_fp16_to_fp32(x[i].d);
- const float min = ggml_fp16_to_fp32(x[i].dmin);
-
const uint8_t * ql = x[i].qs;
const uint8_t * qh = x[i].qh;
+#if QK_K == 256
+
+ const float d = ggml_fp16_to_fp32(x[i].d);
+ const float min = ggml_fp16_to_fp32(x[i].dmin);
+
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
@@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
+#else
+ float d = ggml_fp16_to_fp32(x[i].d);
+ const int8_t * restrict s = x[i].scales;
+ for (int l = 0; l < 8; ++l) {
+ y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+ y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+ y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+ y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+ y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
+ y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
+ y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
+ y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
+ }
+ y += QK_K;
+#endif
}
}
@@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
+#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
ql += 64;
qh += 32;
}
+#else
+ for (int l = 0; l < 32; ++l) {
+ const uint8_t q1 = L[l + 0] & 0xF;
+ const uint8_t q2 = L[l + 32] & 0xF;
+ ql[l] = q1 | (q2 << 4);
+ }
+ for (int l = 0; l < 16; ++l) {
+ qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
+ }
+#endif
x += QK_K;
@@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;
+#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
@@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
qh += 32;
sc += 8;
}
+#else
+ for (int l = 0; l < 16; ++l) {
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ y[l+ 0] = d * sc[0] * q1;
+ y[l+16] = d * sc[1] * q2;
+ y[l+32] = d * sc[2] * q3;
+ y[l+48] = d * sc[3] * q4;
+ }
+ y += 64;
+#endif
}
}
@@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
+#if QK_K == 256
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const block_q2_K * restrict x = vx;
@@ -1201,6 +1436,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif
}
+#else
+
+void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+
+ const block_q2_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
+ const int32x4_t vzero = vdupq_n_s32(0);
+
+ int8x16x4_t q2bytes;
+
+ uint32_t aux32[2];
+ const uint8_t * scales = (const uint8_t *)aux32;
+
+ float sum = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * (float)x[i].d;
+ const float dmin = -y[i].d * (float)x[i].dmin;
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+ aux32[0] = sc[0] & 0x0f0f0f0f;
+ aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+ sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+ int isum1 = 0, isum2 = 0;
+
+ const uint8x16_t q2bits = vld1q_u8(q2);
+
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
+ q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
+ q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
+ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
+ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+ vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+ vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+ isum1 += vaddvq_s16(p1) * scales[0];
+ isum2 += vaddvq_s16(p2) * scales[1];
+
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+ vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+ const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+ vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+ isum1 += vaddvq_s16(p3) * scales[2];
+ isum2 += vaddvq_s16(p4) * scales[3];
+#endif
+ sum += d * (isum1 + isum2);
+
+ }
+
+ *s = sum;
+
+#elif defined __AVX2__
+
+ const __m256i m3 = _mm256_set1_epi8(3);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint32_t ud, um;
+ const uint8_t * restrict db = (const uint8_t *)&ud;
+ const uint8_t * restrict mb = (const uint8_t *)&um;
+
+ float summs = 0;
+
+ // TODO: optimize this
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+ ud = (sc[0] >> 0) & 0x0f0f0f0f;
+ um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+ int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+ summs += dmin * smin;
+
+ const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+ const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+ const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+ const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+
+ const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0));
+ const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1));
+ const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0));
+ const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1));
+
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc);
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc);
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc);
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
+#else
+
+ float sumf = 0;
+
+ int isum[4];
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+
+ int summs = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ summs += y[i].bsums[j] * (sc[j] >> 4);
+ }
+
+ const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+ isum[0] = isum[1] = isum[2] = isum[3] = 0;
+ for (int l = 0; l < 16; ++l) {
+ isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3);
+ isum[1] += q8[l+16] * ((q2[l] >> 2) & 3);
+ isum[2] += q8[l+32] * ((q2[l] >> 4) & 3);
+ isum[3] += q8[l+48] * ((q2[l] >> 6) & 3);
+ }
+ for (int l = 0; l < 4; ++l) {
+ isum[l] *= (sc[l] & 0xF);
+ }
+ sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs;
+ }
+ *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
@@ -1501,6 +1898,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
}
+#else
+
+void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_q3_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+#ifdef __ARM_FEATURE_DOTPROD
+ const int32x4_t vzero = vdupq_n_s32(0);
+#endif
+
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
+ const uint8x16_t mh = vdupq_n_u8(4);
+
+ int8x16x4_t q3bytes;
+
+ uint16_t aux16[2];
+ int8_t * scales = (int8_t *)aux16;
+
+ float sum = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ uint8x16x4_t q3h;
+
+ const uint8x8_t hbits = vld1_u8(x[i].hmask);
+ const uint8x16_t q3bits = vld1q_u8(x[i].qs);
+ const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
+
+ const uint16_t a = *(const uint16_t *)x[i].scales;
+ aux16[0] = a & 0x0f0f;
+ aux16[1] = (a >> 4) & 0x0f0f;
+
+ for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+ int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+ const float d = y[i].d * (float)x[i].d;
+
+ const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
+ q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
+ q3h.val[1] = vandq_u8(mh, htmp);
+ q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
+ q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
+
+ q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
+ q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
+ q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
+ q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
+ isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
+#else
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+ vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+ vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+ vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+ vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+ isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
+#endif
+
+ sum += d * isum;
+
+ }
+
+ *s = sum;
+
+#elif defined __AVX2__
+
+ const __m256i m3 = _mm256_set1_epi8(3);
+ const __m256i m1 = _mm256_set1_epi8(1);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint64_t aux64;
+
+ uint16_t aux16[2];
+ const int8_t * aux8 = (const int8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint16_t a = *(const uint16_t *)x[i].scales;
+ aux16[0] = a & 0x0f0f;
+ aux16[1] = (a >> 4) & 0x0f0f;
+
+ const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+ const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+
+ memcpy(&aux64, x[i].hmask, 8);
+
+ const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+ __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+ __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
+ q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
+ q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
+
+ // load low 2 bits
+ const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+ // prepare low and high bits
+ const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+ const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
+
+ // load Q8 quants
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+ const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+ // multiply with scales
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
+
+ // multiply with block scale and accumulate
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ int32_t scales[4];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict hm = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+ int8_t * restrict a = aux8;
+ for (int l = 0; l < 8; ++l) {
+ a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4);
+ a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4);
+ a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4);
+ a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4);
+ a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4);
+ a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4);
+ a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4);
+ a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
+ }
+
+ scales[0] = (x[i].scales[0] & 0xF) - 8;
+ scales[1] = (x[i].scales[0] >> 4) - 8;
+ scales[2] = (x[i].scales[1] & 0xF) - 8;
+ scales[3] = (x[i].scales[1] >> 4) - 8;
+
+ memset(aux32, 0, 8*sizeof(int32_t));
+ for (int j = 0; j < QK_K/16; ++j) {
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
+ }
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+
+#endif
+
+}
+#endif
+
+#if QK_K == 256
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
@@ -1614,9 +2211,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
- const uint8_t * restrict q4 = x[i].qs;
- const int8_t * restrict q8 = y[i].qs;
-
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
@@ -1624,6 +2218,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
utmp[2] = uaux;
utmp[0] &= kmask1;
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
@@ -1726,7 +2323,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
*s = sumf;
#endif
}
+#else
+void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_q4_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+#ifdef __ARM_FEATURE_DOTPROD
+ const int32x4_t mzero = vdupq_n_s32(0);
+#endif
+
+ float sumf = 0;
+
+ int8x16x2_t q4bytes;
+ int8x16x4_t q8bytes;
+
+ float sum_mins = 0.f;
+
+ uint16_t aux16[2];
+ const uint8_t * restrict scales = (const uint8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+ sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+ const float d = y[i].d * (float)x[i].d[0];
+
+ const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
+
+#ifdef __ARM_FEATURE_DOTPROD
+ q8bytes = vld1q_s8_x4(q8);
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+
+ const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
+
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+ const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
+ const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
+
+#else
+ q8bytes = vld1q_s8_x4(q8);
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+ int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
+
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
+ vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
+ vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
+ int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
+
+#endif
+ sumf += d * (sumi1 + sumi2);
+
+ }
+
+ *s = sumf - sum_mins;
+
+#elif defined __AVX2__
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0;
+
+ uint16_t aux16[2];
+ const uint8_t * scales = (const uint8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+ const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+ const __m256 vd = _mm256_set1_ps(d);
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+ const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+
+ const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);
+
+ const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);
+
+ }
+
+ *s = hsum_float_8(acc) - summs;
+
+#else
+
+ uint8_t aux8[QK_K];
+ int16_t aux16[16];
+ float sums [8];
+ memset(sums, 0, 8*sizeof(float));
+
+ uint16_t s16[2];
+ const uint8_t * restrict scales = (const uint8_t *)s16;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ uint8_t * restrict a = aux8;
+ for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
+ for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4;
+
+ const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+ s16[0] = b[0] & 0x0f0f;
+ s16[1] = (b[0] >> 4) & 0x0f0f;
+
+ sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+
+ for (int j = 0; j < QK_K/32; ++j) {
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+ q8 += 16; a += 16;
+ for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
+ q8 += 16; a += 16;
+ const float dl = d * scales[j];
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
+ }
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+#endif
+
+#if QK_K == 256
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
@@ -1840,18 +2606,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) {
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
-
const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
+#if QK_K == 256
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
+#else
+ // TODO
+ const float d = 0, dmin = 0;
+#endif
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
@@ -1972,8 +2743,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#endif
}
+#else
+
+void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_q5_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const int32x4_t mzero = vdupq_n_s32(0);
+ const uint8x16_t mh = vdupq_n_u8(16);
+
+ int8x16x4_t q5bytes;
+ uint8x16x4_t q5h;
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * (float)x[i].d;
+ const int8_t * sc = x[i].scales;
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint8x8_t qhbits = vld1_u8(qh);
+
+ const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
+ const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+ const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
+ q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
+ q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2));
+ q5h.val[2] = vbicq_u8(mh, htmp);
+ q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2));
+
+ q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0]));
+ q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1]));
+ q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
+ q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+ int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+ int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+ int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+ int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+ sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+#else
+
+ const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+ vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+ vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+ int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
+ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+ vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+ const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+ vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+ sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+ sumf += d*sumi;
+#endif
+
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+ const __m256i mone = _mm256_set1_epi8(1);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+ const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+ const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+
+ int64_t aux64;
+ memcpy(&aux64, x[i].qh, 8);
+ const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+ const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
+
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+ const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+ const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+ const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+ const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));
+
+ acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
+
+
+ uint8_t aux8[QK_K];
+ int16_t aux16[16];
+ float sums [8];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].qs;
+ const uint8_t * restrict hm = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+ uint8_t * restrict a = aux8;
+ for (int l = 0; l < 32; ++l) {
+ a[l+ 0] = q4[l] & 0xF;
+ a[l+32] = q4[l] >> 4;
+ }
+ for (int is = 0; is < 8; ++is) {
+ uint8_t m = 1 << is;
+ for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
+ }
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const int8_t * restrict sc = x[i].scales;
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float dl = d * sc[j];
+ for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
+ q8 += 16; a += 16;
+ }
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+#endif
+
+
+#if QK_K == 256
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0);
@@ -2242,3 +3174,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
*s = sumf;
#endif
}
+
+#else
+
+void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ assert(n % QK_K == 0);
+
+ const block_q6_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+
+ float sum = 0;
+
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
+ const int32x4_t vzero = vdupq_n_s32(0);
+ const int8x16_t m32s = vdupq_n_s8(32);
+
+ const uint8x16_t mone = vdupq_n_u8(3);
+
+ int8x16x4_t q6bytes;
+ uint8x16x4_t q6h;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d_all = (float)x[i].d;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ int32_t isum = 0;
+
+ uint8x16_t qhbits = vld1q_u8(qh);
+ uint8x16x2_t q6bits = vld1q_u8_x2(q6);
+ int8x16x4_t q8bytes = vld1q_s8_x4(q8);
+
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
+ uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits, 4);
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits, 6);
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+ q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+ q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+ q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
+ q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+ isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+#else
+
+ int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
+ vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
+ int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
+ vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
+ isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
+
+ int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
+ vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
+ int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
+ vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
+ isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
+#endif
+
+ sum += isum * d_all * y[i].d;
+
+ }
+ *s = sum;
+
+#elif defined __AVX2__
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+ const __m256i m2 = _mm256_set1_epi8(3);
+ const __m256i m32s = _mm256_set1_epi8(32);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+ const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+ const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+ const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+ __m256i sumi = _mm256_setzero_si256();
+
+ const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+ const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+ const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
+ int8_t * restrict a = aux8;
+ for (int l = 0; l < 16; ++l) {
+ a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ }
+ int is = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int scale = x[i].scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ }
+ const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+
+#endif
diff --git a/k_quants.h b/k_quants.h
index 10a0baa..6abe3d7 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -7,7 +7,13 @@
#include <stddef.h>
// Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
//
// Super-block quantization structures
@@ -29,38 +35,67 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
// weight is represented as x = a * q
// 16 blocks of 16 elemenets each
// Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+ uint8_t scales[2];
ggml_fp16_t d; // super-block scale
} block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
+#else
+typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+ uint8_t scales[12]; // scales, quantized with 6 bits
+ ggml_fp16_t d; // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif
// 4-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+ ggml_fp16_t d[2]; // super-block scales/mins
+ uint8_t scales[2]; // 4-bit block scales/mins
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif
// 5-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
+#ifdef GGML_QKK_64
+typedef struct {
+ ggml_fp16_t d; // super-block scale
+ int8_t scales[QK_K/16]; // 8-bit block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
// 6-bit quantization
// weight is represented as x = a * q
diff --git a/llama.cpp b/llama.cpp
index ac22a48..c41c2a8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -21,9 +21,13 @@
#endif
#ifdef GGML_USE_K_QUANTS
#ifndef QK_K
+#ifdef GGML_QKK_64
+#define QK_K 64
+#else
#define QK_K 256
#endif
#endif
+#endif
#include <array>
#include <ctime>
@@ -2470,6 +2474,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
std::vector<std::thread> workers;
std::mutex mutex;
+ auto use_more_bits = [] (int i_layer, int num_layers) -> bool {
+ return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
+ };
+
size_t idx = 0;
for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) {
llama_buffer read_data;
@@ -2524,15 +2532,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
- (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8 ||
- (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+ use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+ else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+ (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
++i_attention_wv;
} else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
- (i_feed_forward_w2 < n_feed_forward_w2/8 || i_feed_forward_w2 >= 7*n_feed_forward_w2/8 ||
- (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+ use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+ //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
++i_feed_forward_w2;
} else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;