aboutsummaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu83
1 files changed, 50 insertions, 33 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9ebc57a..36a251e 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -515,15 +515,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
const block_q2_K * x = (const block_q2_K *)vx + ib0;
- const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31
- const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 16/K_QUANTS_PER_ITERATION;
- const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
- const int in = tid - step*im; // 0...7
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
- const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
@@ -578,27 +578,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
}
}
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
- const int row = blockIdx.x;
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
+
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q3_K * x = (const block_q3_K *)vx + ib0;
- const int tid = threadIdx.x/2; // 0...15
- const int ix = threadIdx.x%2; // 0, 1
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
- const int n = 2; // iterations in the inner loop
- const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
- const int in = tid - 8*im; // 0...7
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+ const int step = 16/K_QUANTS_PER_ITERATION;
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0....15 or 0...7
const uint8_t m = 1 << (4*im);
- const int l0 = n*in; // 0...28 in steps of 4
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;
@@ -609,7 +612,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
float tmp = 0; // partial sum for thread in warp
- for (int i = ix; i < num_blocks_per_row; i += 2) {
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
@@ -650,22 +653,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
}
}
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
- const int row = blockIdx.x;
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
+ if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
- const int tid = threadIdx.x/2; // 0...15
- const int ix = threadIdx.x%2;
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
- const int il = tid/4; // 0...3
- const int ir = tid - 4*il;// 0...3
- const int n = 4;
+ const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+ const int il = tid/step; // 0...3
+ const int ir = tid - step*il; // 0...7 or 0...3
+ const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
@@ -681,7 +687,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
float tmp = 0; // partial sum for thread in warp
- for (int i = ix; i < num_blocks_per_row; i += 2) {
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
@@ -736,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
- const int n = 4;
+ const int n = 2;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
@@ -775,11 +781,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
- sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0));
- sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0));
- sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0));
- sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0));
- smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
@@ -1311,7 +1322,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
@@ -1320,14 +1331,20 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 1, 1);
- dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 1, 1);
- dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {