diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 110 |
1 files changed, 71 insertions, 39 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 35d2e45..98170a3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -83,9 +83,19 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); +#define WARP_SIZE 32 + #define CUDA_MUL_BLOCK_SIZE 256 + #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 -#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_CUDA_DMMV_X +#define GGML_CUDA_DMMV_X 32 +#endif +#ifndef GGML_CUDA_DMMV_Y +#define GGML_CUDA_DMMV_Y 1 +#endif static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) dequantize_kernel(vx, ib, iqs, v0, v1); } -template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel> +template <int qk, int qr, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { - const int row = blockIdx.x; + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int y_offset = qr == 1 ? 1 : qk/2; - __shared__ float tmp[block_size]; // separate sum for each thread - tmp[tid] = 0; + float tmp = 0; // partial sum for thread in warp - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index const int iybs = col - col%qk; // y block start index - // dequantize - float v0, v1; - dequantize_kernel(vx, ib, iqs, v0, v1); +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + float v0, v1; + dequantize_kernel(vx, ib, iqs + j/qr, v0, v1); + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; + // matrix multiplication + tmp += v0 * y[iybs + iqs + j/qr + 0]; + tmp += v1 * y[iybs + iqs + j/qr + y_offset]; + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 + } } // sum up partial sums and write back result __syncthreads(); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } + if (tid == 0) { - dst[row] = tmp[0]; + dst[row] = tmp; } } @@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c } static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); - dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16> - <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + dequantize_mul_mat_vec<1, 1, convert_f16> + <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); } static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { |