path: root/ggml-cuda.cu
author    Georgi Gerganov <ggerganov@gmail.com>  2023-07-23 15:09:47 +0300
committer GitHub <noreply@github.com>            2023-07-23 15:09:47 +0300
commit    e76d630df17e235e6b9ef416c45996765d2e36fb (patch)
tree      15e0e9648f9b0e398b43e888216a73f84098ff3a /ggml-cuda.cu
parent    1d0824b2476e7fda09751a0235c9e571b76d6f2c (diff)
llama : grouped-query attention + LLaMAv2 70B support (#2276)
* CUDA: GQA implementation
* llama : support for GQA and LLaMAv2 70B ggml-ci
* py : fix hparams parsing (if-else blocks) ggml-ci
* py : oh boy .. ggml-ci
* help : fix gqa value for 70B ggml-ci
---------
Co-authored-by: JohannesGaessler <johannesg@5d6.de>
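For context on the changes below: with grouped-query attention each K/V head is shared by a whole group of query heads, so a kernel indexed by the query (dst) channel must integer-divide by the group size to find the matching K/V channel in src0. A minimal host-side sketch of that mapping, with hypothetical names (kv_channel is not part of this commit):

    // Hypothetical illustration of the GQA channel mapping used by the kernels below.
    // nchannels_y = query heads, nchannels_x = K/V heads; GQA requires an integer ratio.
    #include <cassert>
    #include <cstdio>

    static int kv_channel(int channel, int nchannels_x, int nchannels_y) {
        assert(nchannels_y % nchannels_x == 0);
        const int gqa = nchannels_y / nchannels_x; // query heads per K/V head
        return channel / gqa;                      // whole group reads the same K/V head
    }

    int main() {
        // LLaMAv2 70B: 64 query heads share 8 K/V heads (groups of 8)
        for (int c = 0; c < 64; c += 8) {
            printf("query head %2d -> K/V head %d\n", c, kv_channel(c, 8, 64));
        }
        return 0;
    }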
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  71
1 file changed, 45 insertions(+), 26 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2c5d157..7204474 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1787,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
}
}
-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / (nchannels_y / nchannels_x);
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
@@ -1807,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
}
// x is transposed and permuted
- const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+ const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
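The key change in this kernel: the z block index now runs over the y/dst channels, and channel_x = channel / (nchannels_y / nchannels_x) selects the K/V channel that the whole query group reads from. A stripped-down CUDA sketch of the same indexing (illustrative only; broadcast_channel_demo is hypothetical):

    // Hypothetical sketch of the broadcast indexing, not the kernel above.
    __global__ void broadcast_channel_demo(const float * x, float * dst,
            const int ncols, const int nchannels_x, const int nchannels_y) {
        const int channel   = blockIdx.z;                            // y/dst channel
        const int channel_x = channel / (nchannels_y / nchannels_x); // shared K/V channel
        const int col       = blockIdx.x*blockDim.x + threadIdx.x;
        if (col < ncols) {
            // every query channel in a group reads the same src0 row
            dst[channel*ncols + col] = x[channel_x*ncols + col];
        }
    }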
@@ -1835,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
- const int row_stride_x, const int channel_stride_x) {
+ const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / channel_x_divisor;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
@@ -1857,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
break;
}
- const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+ const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
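Unlike the p021 kernel, this one takes the head ratio precomputed on the host as channel_x_divisor, so each thread does a single integer division. A hedged sketch of the host-side derivation (make_channel_x_divisor is hypothetical):

    #include <cassert>

    // Hypothetical helper mirroring how the launcher below derives the divisor.
    static inline int make_channel_x_divisor(const int nchannels_x, const int nchannels_y) {
        assert(nchannels_y % nchannels_x == 0); // GQA needs an integer head ratio
        return nchannels_y / nchannels_x;       // kernel then does one division per thread
    }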
@@ -2366,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}
-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
- const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+ const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+ mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
}
static void ggml_mul_mat_vec_nc_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
- const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+ const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
- const dim3 block_nums(1, nrows_x, nchannels_x);
+ const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
- (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
}
static void ggml_cpy_f32_f32_cuda(
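In both launchers the grid's z dimension is now sized by nchannels_y, since dst has one channel per query head; previously it matched src0's channel count. A sketch of the geometry, assuming WARP_SIZE is 32 (gqa_launch_dims is hypothetical):

    #include <cuda_runtime.h>

    // Hypothetical illustration of the new launch geometry (assumes WARP_SIZE == 32).
    static void gqa_launch_dims(const int nrows_x, const int nchannels_y,
            dim3 & block_nums, dim3 & block_dims) {
        block_dims = dim3(32, 1, 1);                 // one warp reduces one output row
        block_nums = dim3(1, nrows_x, nchannels_y);  // a block per (row, query channel)
        // for LLaMAv2 70B: 64 z-blocks reading from only 8 K/V channels
    }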
@@ -3143,6 +3151,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+ GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
@@ -3154,12 +3165,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
// strides for iteration over dims 3 and 2
- const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
- const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+ const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+ const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+ const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
const int64_t src0_stride = ne00 * ne01 * stride_mod;
const int64_t src1_stride = ne10 * ne11 * stride_mod;
const int64_t dst_stride = ne0 * ne1 * stride_mod;
+ const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+ const int64_t i03_max = flatten_rows ? 1 : ne03;
+ const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+ const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+ GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
@@ -3176,6 +3194,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+ GGML_ASSERT(!(split && ne02 < ne12));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -3212,7 +3231,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
} else {
row_low = 0;
- row_high = nrows0;
+ row_high = nrows0*i02_divisor;
}
if (row_low == row_high) {
continue;
@@ -3260,16 +3279,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
}
- const int64_t i03_max = flatten_rows ? 1 : ne03;
- const int64_t i02_max = flatten_rows ? 1 : ne02;
- const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
for (int64_t i03 = 0; i03 < i03_max; i03++) {
const int64_t i13 = i03 % ne13;
for (int64_t i02 = 0; i02 < i02_max; i02++) {
const int64_t i12 = i02 % ne12;
- const int64_t i0 = i03*ne02 + i02;
+ const int64_t i0 = i03*i02_max + i02;
// i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3303,10 +3318,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t i11 = i13*ne12 + i12;
// for split tensors the data begins at i0 == i0_offset_low
- char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
- float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+ char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+ float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
- float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+ float * dst_ddf_i  = dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
// for split tensors the data pointer needs to be rounded down
// to the bin edge for i03, i02 bins beyond the first
@@ -3345,11 +3360,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
}
- if (!src0_on_device || !src0_is_contiguous) {
+ if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
if (src0_is_f32) {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
} else {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
}
}
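Taken together, these ggml_cuda_op changes make the host loop iterate over the larger channel count (ne12 when src1 broadcasts against src0) and re-use each src0 slice for i02_divisor consecutive iterations, copying it to the device only once per group. A minimal standalone sketch of that reuse pattern:

    #include <cstdio>

    // Hypothetical sketch of the broadcast loop; the real code indexes device buffers.
    int main() {
        const int ne02 = 2, ne12 = 8;            // src0 vs src1 channels (ne02 < ne12)
        const int i02_divisor = ne12 / ne02;     // 4 dst channels per src0 slice
        for (int i02 = 0; i02 < ne12; i02++) {   // i02_max is ne12, the larger count
            if (i02 % i02_divisor == 0) {
                printf("copy src0 slice %d to the device\n", i02/i02_divisor);
            }
            printf("  compute dst channel %d from src0 slice %d\n", i02, i02/i02_divisor);
        }
        return 0;
    }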
@@ -3503,6 +3518,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
+ const int64_t ne12 = src1->ne[2];
+
CUDA_CHECK(cudaSetDevice(g_main_device));
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
@@ -3515,7 +3532,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
- ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
}
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3529,6 +3546,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
+ const int64_t ne12 = src1->ne[2];
+
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
@@ -3547,7 +3566,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int row_stride_x = nb01 / sizeof(half);
const int channel_stride_x = nb02 / sizeof(half);
- ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
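At these call sites ne12 comes straight from src1->ne[2]: src0 (the permuted K or V cache) carries the n_kv_head channels while src1 (derived from Q) carries the full n_head channels, so GQA works whenever the head counts divide evenly. A hedged sketch of that shape constraint (gqa_shapes_ok is hypothetical):

    #include <cstdint>

    // Hypothetical check mirroring the implicit assumption in the wrappers above.
    static bool gqa_shapes_ok(const int64_t ne02, const int64_t ne12) {
        return ne02 > 0 && ne12 % ne02 == 0;   // e.g. 64 % 8 == 0 for LLaMAv2 70B
    }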