aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-cuda.cu72
1 files changed, 59 insertions, 13 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index f07bdc7..2c5d157 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -935,12 +935,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
+#if K_QUANTS_PER_ITERATION == 2
+ uint32_t q32[4];
+ const uint8_t * q4 = (const uint8_t *)q32;
+#else
+ uint16_t q16[4];
+ const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
- const uint8_t * q1 = x[i].qs + q_offset;
- const uint8_t * q2 = q1 + 64;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
@@ -953,14 +959,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+#if K_QUANTS_PER_ITERATION == 2
+ const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+ const uint32_t * q2 = q1 + 16;
+
+ q32[0] = q1[0] & 0x0f0f0f0f;
+ q32[1] = q1[0] & 0xf0f0f0f0;
+ q32[2] = q2[0] & 0x0f0f0f0f;
+ q32[3] = q2[0] & 0xf0f0f0f0;
+
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
- for (int l = 0; l < n; ++l) {
- s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
- s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+ for (int l = 0; l < 4; ++l) {
+ s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+ s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
- tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+ const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+ const uint16_t * q2 = q1 + 32;
+
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[0] & 0xf0f0;
+ q16[2] = q2[0] & 0x0f0f;
+ q16[3] = q2[0] & 0xf0f0;
+
+ float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < 2; ++l) {
+ s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+ s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
}
#else
@@ -1521,7 +1554,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
- const int bq8_offset = QR4_K * (iqs / QI8_1);
+ const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
float sumf_d = 0.0f;
float sumf_m = 0.0f;
@@ -1531,11 +1564,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
- for (int i = 0; i < QR4_K; ++i) {
- const int isc = bq8_offset + i;
+ const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
- uint8_t sc, m;
- get_scale_min_k4(isc, bq4_K->scales, sc, m);
+ for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
@@ -1543,8 +1585,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const int vi = (v >> (4*i)) & 0x0F0F0F0F;
- sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
- sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values
+ sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product
+ sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
@@ -2497,7 +2539,9 @@ static size_t g_scratch_offset = 0;
static int g_device_count = -1;
static int g_main_device = 0;
+#ifndef GGML_CUDA_FORCE_DMMV
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+#endif
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -2520,7 +2564,9 @@ void ggml_init_cublas() {
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
+#ifndef GGML_CUDA_FORCE_DMMV
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif
}
for (int id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;