From 16b9cd193965769089881bb8ec012fccca7b37b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 19 Jun 2023 10:23:56 +0200 Subject: Convert vector to f16 for dequantize mul mat vec (#1913) * Convert vector to f16 for dmmv * compile option * Added compilation option description to README * Changed cmake CUDA_ARCHITECTURES from "OFF" to "native" --- ggml-cuda.cu | 202 ++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 139 insertions(+), 63 deletions(-) (limited to 'ggml-cuda.cu') diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 16488b9..9ebc57a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -50,7 +50,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } while (0) #endif // CUDART_VERSION >= 11 -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); +#ifdef GGML_CUDA_DMMV_F16 +typedef half dfloat; // dequantize float +typedef half2 dfloat2; +#else +typedef float dfloat; // dequantize float +typedef float2 dfloat2; +#endif //GGML_CUDA_DMMV_F16 + +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -234,82 +242,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } -static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; - const uint8_t vui = x[ib].qs[iqs]; + const int vui = x[ib].qs[iqs]; - const int8_t vi0 = vui & 0xF; - const int8_t vi1 = vui >> 4; + v.x = vui & 0xF; + v.y = vui >> 4; - v0 = (vi0 - 8)*d; - v1 = (vi1 - 8)*d; +#ifdef GGML_CUDA_DMMV_F16 + v = __hsub2(v, {8.0f, 8.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const float d = x[ib].d; - const float m = x[ib].m; + const dfloat d = x[ib].d; + const dfloat m = x[ib].m; - const uint8_t vui = x[ib].qs[iqs]; + const int vui = x[ib].qs[iqs]; - const int8_t vi0 = vui & 0xF; - const int8_t vi1 = vui >> 4; + v.x = vui & 0xF; + v.y = vui >> 4; - v0 = vi0*d + m; - v1 = vi1*d + m; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); - const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; - const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v0 = x0*d; - v1 = x1*d; +#ifdef GGML_CUDA_DMMV_F16 + v = __hsub2(v, {16.0f, 16.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const float d = x[ib].d; - const float m = x[ib].m; + const dfloat d = x[ib].d; + const dfloat m = x[ib].m; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); - const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); - const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - v0 = x0*d + m; - v1 = x1*d + m; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_DMMV_F16 } -static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; - const float d = x[ib].d; + const dfloat d = x[ib].d; - const int8_t vi0 = x[ib].qs[iqs + 0]; - const int8_t vi1 = x[ib].qs[iqs + 1]; + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; - v0 = vi0*d; - v1 = vi1*d; +#ifdef GGML_CUDA_DMMV_F16 + v = __hmul2(v, {d, d}); +#else + v.x *= d; + v.y *= d; +#endif // GGML_CUDA_DMMV_F16 } //================================== k-quants @@ -843,11 +875,12 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; - v0 = __half2float(x[ib + iqs + 0]); - v1 = __half2float(x[ib + iqs + 1]); + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; } template @@ -864,13 +897,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) const int y_offset = qr == 1 ? 1 : qk/2; // dequantize - float & v0 = y[iybs + iqs + 0]; - float & v1 = y[iybs + iqs + y_offset]; - dequantize_kernel(vx, ib, iqs, v0, v1); + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; } template -static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { +static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block const int row = blockIdx.y*blockDim.y + threadIdx.y; @@ -885,7 +920,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int y_offset = qr == 1 ? 1 : qk/2; - float tmp = 0.0f; // partial sum for thread in warp +// partial sum for each thread +#ifdef GGML_CUDA_DMMV_F16 + half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_CUDA_DMMV_F16 for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; @@ -899,14 +939,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, // process 2 vals per j iter // dequantize - float v0, v1; - dequantize_kernel(vx, ib, iqs + j/qr, v0, v1); // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); // matrix multiplication - tmp += v0 * y[iybs + iqs + j/qr + 0]; - tmp += v1 * y[iybs + iqs + j/qr + y_offset]; // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_CUDA_DMMV_F16 + tmp += __hmul2(v, { + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset] + }); +#else + tmp += v.x * y[iybs + iqs + j/qr + 0]; + tmp += v.y * y[iybs + iqs + j/qr + y_offset]; +#endif // GGML_CUDA_DMMV_F16 } } @@ -918,7 +965,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } if (tid == 0) { +#ifdef GGML_CUDA_DMMV_F16 + dst[row] = tmp.x + tmp.y; +#else dst[row] = tmp; +#endif // GGML_CUDA_DMMV_F16 } } @@ -1213,7 +1264,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu dequantize_block_q6_K<<>>(vx, y); } -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1222,7 +1273,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1231,7 +1282,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1240,7 +1291,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1249,7 +1300,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1299,7 +1350,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c dequantize_block<1, 1, convert_f16><<>>(vx, y, k); } -static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); @@ -1714,21 +1765,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( const int64_t ne00 = src0->ne[0]; const int64_t nrows = i01_high - i01_low; +// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_CUDA_DMMV_F16 + size_t ash; + dfloat * src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); + ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00, + ne00, 1, sizeof(float), 0, 0, + ne00, 1, sizeof(half), 0, 0, cudaStream_main); + } +#else + dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion +#endif // GGML_CUDA_DMMV_F16 + switch (src0->type) { case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); @@ -1746,7 +1816,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); break; case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); break; default: GGML_ASSERT(false); @@ -1754,6 +1824,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( } CUDA_CHECK(cudaGetLastError()); +#ifdef GGML_CUDA_DMMV_F16 + if (src1_convert_f16) { + ggml_cuda_pool_free(src1_dfloat, ash); + } +#endif // GGML_CUDA_DMMV_F16 + (void) src1; (void) dst; (void) src0_ddf_i; -- cgit v1.2.3