aboutsummaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-07-11 22:53:34 +0300
committerGitHub <noreply@github.com>2023-07-11 22:53:34 +0300
commit20d7740a9b45f6e5b247fa3738fdda35e18c2e8a (patch)
treec64588319bc093347062551e2871100d853e3861 /ggml-cuda.cu
parent5bf2a2771886ee86137e01dbc7492f78fb392066 (diff)
ggml : sync (abort callback, mul / add broadcast, fix alibi) (#2183)
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu115
1 files changed, 88 insertions, 27 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 1673e7e..2fb30c6 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
};
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
- if (i >= k) {
+ if (i >= kx) {
return;
}
- dst[i] = x[i] + y[i];
+ dst[i] = x[i] + y[i%ky];
}
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ const float eps = 1e-5f;
+
+ float mean = 0.0f;
+ float var = 0.0f;
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ const float xi = x[row*ncols + col];
+ mean += xi;
+ var += xi * xi;
+ }
+
+ // sum up partial sums
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+ var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+ }
+
+ mean /= ncols;
+ var = var / ncols - mean * mean;
+ const float inv_var = rsqrtf(var + eps);
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+ }
+}
+
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-6;
+ const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
const float mean = tmp / ncols;
- const float scale = 1.0f / sqrtf(mean + eps);
+ const float scale = rsqrtf(mean + eps);
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
@@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
- const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
- add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+ const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+ add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
GGML_ASSERT(src1_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr);
- const int64_t ne0 = src0->ne[0];
+ const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
+ const int64_t ne10 = src1->ne[0];
+
// compute
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else {
GGML_ASSERT(false);
}
@@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
-
- for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
- const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
- float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
- float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
- float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
- // compute
- mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
- }
+ mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
(void) dst;
(void) src0_ddq_i;
@@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
(void) i1;
}
+inline void ggml_cuda_op_norm(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ // compute
+ norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
inline void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
}
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
- cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+ CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf;
@@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_silu;
break;
+ case GGML_OP_NORM:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_norm;
+ break;
case GGML_OP_RMS_NORM:
if (!any_on_device) {
return false;