diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2023-07-08 00:25:15 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-08 00:25:15 +0200 |
commit | 061f5f8d2109bb7adcbd40f1b456d887c5a1df25 (patch) | |
tree | b9bd433a271a1e34897d58e004051e09bb60a87a | |
parent | 84525e7962bee0abef91108948bbf7f7bfdcf421 (diff) |
CUDA: add __restrict__ to mul mat vec kernels (#2140)
-rw-r--r-- | ggml-cuda.cu | 53 |
1 files changed, 25 insertions, 28 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7965ff7..ec41e35 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -59,8 +59,8 @@ typedef float2 dfloat2; #endif //GGML_CUDA_DMMV_F16 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); -typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); -typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); +typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream); +typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_cuda_op_t)( @@ -131,7 +131,7 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); -typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs); +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs); //================================= k-quants @@ -407,7 +407,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in //================================== k-quants -static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) { const int i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; @@ -440,7 +440,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { } -static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) { const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; @@ -504,7 +504,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t } #endif -static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; @@ -544,7 +544,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { #endif } -static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; const int i = blockIdx.x; @@ -590,7 +590,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { #endif } -static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; const int i = blockIdx.x; @@ -634,7 +634,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { #endif } -static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { +static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -742,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { +static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -846,7 +846,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { +static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -949,7 +949,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { +static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; @@ -1053,7 +1053,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { +static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); @@ -1171,7 +1171,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -static __global__ void quantize_q8_1(const float * x, void * vy, const int k) { +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -1207,7 +1207,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) { } template <int qk, int qr, dequantize_kernel_t dequantize_kernel> -static __global__ void dequantize_block(const void * vx, float * y, const int k) { +static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; if (i >= k) { @@ -1227,7 +1227,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) y[iybs + iqs + y_offset] = v.y; } -static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; @@ -1252,7 +1252,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons #endif // __CUDA_ARCH__ >= 600 } -static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; @@ -1277,7 +1277,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons #endif // __CUDA_ARCH__ >= 600 } -static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; @@ -1312,7 +1312,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons #endif // __CUDA_ARCH__ >= 600 } -static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; @@ -1346,7 +1346,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons #endif // __CUDA_ARCH__ >= 600 } -static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; @@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons } template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda> -static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) { +static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row >= nrows) { @@ -1404,7 +1404,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d } template <int qk, int qr, dequantize_kernel_t dequantize_kernel> -static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { +static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block const int row = blockIdx.y*blockDim.y + threadIdx.y; @@ -1471,7 +1471,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } } -static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { +static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; @@ -1518,7 +1518,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl } static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int channel_stride_x) { const half * x = (const half *) vx; @@ -2355,10 +2355,7 @@ inline void ggml_cuda_op_mul_mat_vec( src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0; - // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6. - // However, they have bad performance with Pascal cards. - // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q. - const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented; + const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented; #endif if (use_mul_mat_vec_q) { |