From e76d630df17e235e6b9ef416c45996765d2e36fb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 23 Jul 2023 15:09:47 +0300 Subject: llama : grouped-query attention + LLaMAv2 70B support (#2276) * CUDA: GQA implementation * llama : support for GQA and LLaMAv2 70B ggml-ci * py : fix hparams parsing (if-else blocks) ggml-ci * py : oh boy .. ggml-ci * help : fix gqa value for 70B ggml-ci --------- Co-authored-by: JohannesGaessler --- ggml-cuda.cu | 71 ++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 26 deletions(-) (limited to 'ggml-cuda.cu') diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c5d157..7204474 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1787,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } } -static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1807,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const } // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -1835,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1857,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -2366,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); } static void ggml_mul_mat_vec_nc_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); } static void ggml_cpy_f32_f32_cuda( @@ -3143,6 +3151,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t ne11 = use_src1 ? src1->ne[1] : 1; const int64_t ne12 = use_src1 ? src1->ne[2] : 1; const int64_t ne13 = use_src1 ? src1->ne[3] : 1; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(ne03 == ne13); const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; @@ -3154,12 +3165,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); // strides for iteration over dims 3 and 2 - const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03; - const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1; + const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13; + const int64_t num_iters = flatten_rows ? 1 : num_iters_0; + const int64_t stride_mod = flatten_rows ? num_iters_0 : 1; const int64_t src0_stride = ne00 * ne01 * stride_mod; const int64_t src1_stride = ne10 * ne11 * stride_mod; const int64_t dst_stride = ne0 * ne1 * stride_mod; + const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; + const int64_t i03_max = flatten_rows ? 1 : ne03; + const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12); + const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02; + GGML_ASSERT(!(flatten_rows && ne02 < ne12)); + const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); @@ -3176,6 +3194,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 < ne12)); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -3212,7 +3231,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; } else { row_low = 0; - row_high = nrows0; + row_high = nrows0*i02_divisor; } if (row_low == row_high) { continue; @@ -3260,16 +3279,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); } - const int64_t i03_max = flatten_rows ? 1 : ne03; - const int64_t i02_max = flatten_rows ? 1 : ne02; - const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; - for (int64_t i03 = 0; i03 < i03_max; i03++) { const int64_t i13 = i03 % ne13; for (int64_t i02 = 0; i02 < i02_max; i02++) { const int64_t i12 = i02 % ne12; - const int64_t i0 = i03*ne02 + i02; + const int64_t i0 = i03*i02_max + i02; // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs const int64_t i0_offset_low = row_low/rows_per_iter; @@ -3303,10 +3318,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t i11 = i13*ne12 + i12; // for split tensors the data begins at i0 == i0_offset_low - char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; - float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; + char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs; + float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride; float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; + float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; // for split tensors the data pointer needs to be rounded down // to the bin edge for i03, i02 bins beyond the first @@ -3345,11 +3360,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } } - if (!src0_on_device || !src0_is_contiguous) { + if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } else { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } } @@ -3503,6 +3518,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; @@ -3515,7 +3532,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -3529,6 +3546,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; @@ -3547,7 +3566,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int row_stride_x = nb01 / sizeof(half); const int channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { -- cgit v1.2.3