author     Georgi Gerganov <ggerganov@gmail.com>  2023-07-14 16:36:41 +0300
committer  Georgi Gerganov <ggerganov@gmail.com>  2023-07-14 16:36:41 +0300
commit     697966680b27d9b4f05668605b863cb9aea3e15f (patch)
tree       c5c07b2ec21d485c01feb0a704ac996f01cf7af3 /ggml-cuda.cu
parent     27ad57a69b85bf12420a27e9945e580cc280be57 (diff)
ggml : sync (ggml_conv_2d, fix mul_mat bug, CUDA GLM rope)
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  54
1 file changed, 52 insertions(+), 2 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e0d5e91..920466a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1667,6 +1667,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+ const int col = blockDim.x*blockIdx.x + threadIdx.x;
+ const int half_n_dims = ncols/4;
+
+ if (col >= half_n_dims) {
+ return;
+ }
+
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i = row*ncols + col;
+
+ const float col_theta_scale = powf(theta_scale, col);
+
+ const float theta = p*col_theta_scale;
+ const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + half_n_dims];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
+
+ const float block_theta = block_p*col_theta_scale;
+ const float sin_block_theta = sinf(block_theta);
+ const float cos_block_theta = cosf(block_theta);
+
+ const float x2 = x[i + half_n_dims * 2];
+ const float x3 = x[i + half_n_dims * 3];
+
+ dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
+ dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
+}
+
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
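
For reference, each thread of rope_glm_f32 rotates two independent pairs per column: (x0, x1) by the in-context angle p*theta_scale^col, and (x2, x3) by the block angle block_p*theta_scale^col. Below is a plain-C host reference that mirrors the same math; the helper name rope_glm_f32_ref is hypothetical and not part of this commit, a sketch for spot-checking kernel output against the CPU.

#include <math.h>

// Hypothetical CPU reference for rope_glm_f32 (not part of this commit).
// Mirrors the kernel: each column in [0, ncols/4) drives two rotations.
static void rope_glm_f32_ref(const float * x, float * dst, const int ncols, const int nrows,
                             const float p, const float block_p, const float theta_scale) {
    const int half_n_dims = ncols/4;
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < half_n_dims; ++col) {
            const int i = row*ncols + col;
            const float col_theta_scale = powf(theta_scale, col);

            // first pair: rotated by the in-context position p
            const float theta = p*col_theta_scale;
            const float x0 = x[i + 0];
            const float x1 = x[i + half_n_dims];
            dst[i + 0]           = x0*cosf(theta) - x1*sinf(theta);
            dst[i + half_n_dims] = x0*sinf(theta) + x1*cosf(theta);

            // second pair: rotated by the block position block_p
            const float block_theta = block_p*col_theta_scale;
            const float x2 = x[i + half_n_dims*2];
            const float x3 = x[i + half_n_dims*3];
            dst[i + half_n_dims*2] = x2*cosf(block_theta) - x3*sinf(block_theta);
            dst[i + half_n_dims*3] = x2*sinf(block_theta) + x3*cosf(block_theta);
        }
    }
}
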
@@ -2064,6 +2098,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
}
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
+ GGML_ASSERT(nrows % 4 == 0);
+ const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+ const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(num_blocks_x, nrows, 1);
+ rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+}
+
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
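
A note on the launch geometry above: the launcher maps one row per grid.y and covers the columns with blocks of 4*CUDA_ROPE_BLOCK_SIZE threads in x. num_blocks_x is computed over all ncols, while the kernel only does work for col < ncols/4, so roughly three quarters of the x-threads hit the early-return guard. A minimal host sketch of the arithmetic, assuming CUDA_ROPE_BLOCK_SIZE is 256 (its value elsewhere in this file) and using example sizes:

#include <stdio.h>

#define CUDA_ROPE_BLOCK_SIZE 256  // assumed; defined earlier in ggml-cuda.cu

int main(void) {
    const int ncols = 4096;  // example row width
    const int nrows = 32;    // example row count; satisfies nrows % 4 == 0
    const int block_x      = 4*CUDA_ROPE_BLOCK_SIZE;           // 1024 threads
    const int num_blocks_x = (ncols + block_x - 1) / block_x;  // ceil-div -> 4
    // 4 blocks * 1024 threads = 4096 x-threads, but only ncols/4 = 1024 of
    // them satisfy col < half_n_dims; the rest return immediately.
    printf("grid=(%d,%d,1) block=(%d,1,1)\n", num_blocks_x, nrows, block_x);
    return 0;
}
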
@@ -2618,13 +2660,21 @@ inline void ggml_cuda_op_rope(
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
- GGML_ASSERT(mode == 0);
+ const int n_ctx = ((int32_t *) src1->data)[3];
const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+ bool is_glm = mode & 4;
+
// compute
- rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+ if (is_glm) {
+ const float id_p = min(p, n_ctx - 2.f);
+ const float block_p = max(p - (n_ctx - 2.f), 0.f);
+ rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+ } else {
+ rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+ }
(void) dst;
(void) src0_ddq_i;
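
The GLM branch never feeds the raw position p to the kernel: the first rotation sees the position saturated at n_ctx - 2 (id_p), and the second sees only the overflow past that point (block_p). This appears to mirror ChatGLM-style two-channel position ids, where one channel caps at the prompt boundary and the other counts generated tokens past it. A small host sketch of the split, with a hypothetical helper name:

#include <stdio.h>

// Hypothetical helper (not part of this commit) with the same min/max
// semantics as the GLM branch above.
static void glm_split_pos(float p, int n_ctx, float * id_p, float * block_p) {
    const float cap = n_ctx - 2.f;
    *id_p    = p < cap ? p : cap;              // min(p, n_ctx - 2)
    *block_p = p - cap > 0.f ? p - cap : 0.f;  // max(p - (n_ctx - 2), 0)
}

int main(void) {
    float id_p, block_p;
    glm_split_pos(100.f, 2048, &id_p, &block_p);   // well inside the context
    printf("p=100:  id_p=%g block_p=%g\n", id_p, block_p);  // 100, 0
    glm_split_pos(2050.f, 2048, &id_p, &block_p);  // past n_ctx - 2
    printf("p=2050: id_p=%g block_p=%g\n", id_p, block_p);  // 2046, 4
    return 0;
}
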