author     SIGSEGV <21287366+akr2002@users.noreply.github.com>  2023-07-12 19:18:43 +0530
committer  GitHub <noreply@github.com>                          2023-07-12 19:18:43 +0530
commit     2516af4cd61f509c995b4f78fdf123cba33f3509 (patch)
tree       de7324f01b9454fb30e4d827b8300d02fd982ed3 /ggml-cuda.cu
parent     ff34a7d385fc47c4d432fd8c19306d5aca814d05 (diff)
parent     4e7464ef88885cb3532738b03cac890f4077fa20 (diff)

Merge branch 'ggerganov:master' into master
Diffstat (limited to 'ggml-cuda.cu')
 -rw-r--r--  ggml-cuda.cu  150
 1 file changed, 111 insertions(+), 39 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index fd36f17..89e69bd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ const float eps = 1e-5f;
+
+ float mean = 0.0f;
+ float var = 0.0f;
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ const float xi = x[row*ncols + col];
+ mean += xi;
+ var += xi * xi;
+ }
+
+ // sum up partial sums
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+ var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+ }
+
+ mean /= ncols;
+ var = var / ncols - mean * mean;
+ const float inv_var = rsqrtf(var + eps);
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+ }
+}
+
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-6;
+ const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
const float mean = tmp / ncols;
- const float scale = 1.0f / sqrtf(mean + eps);
+ const float scale = rsqrtf(mean + eps);
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
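
For context on the new norm_f32 kernel above: each block handles one row with a single warp; every thread accumulates strided partial sums of x and x*x, the partials are combined across the warp with the __shfl_xor_sync butterfly reduction, and the row is then written out as (x - mean) * rsqrt(var + eps). A minimal CPU sketch of the same computation (norm_f32_ref is my illustration, not code from the file; it assumes row-major storage):

    #include <math.h>

    // Reference (host) version of the per-row normalization done by norm_f32.
    static void norm_f32_ref(const float * x, float * dst, int nrows, int ncols) {
        const float eps = 1e-5f;                     // same epsilon as the kernel
        for (int row = 0; row < nrows; ++row) {
            float mean = 0.0f;
            float var  = 0.0f;
            for (int col = 0; col < ncols; ++col) {
                const float xi = x[row*ncols + col];
                mean += xi;
                var  += xi * xi;
            }
            mean /= ncols;
            var   = var/ncols - mean*mean;           // Var[x] = E[x^2] - E[x]^2
            const float inv_std = 1.0f / sqrtf(var + eps);
            for (int col = 0; col < ncols; ++col) {
                dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
            }
        }
    }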
@@ -1229,7 +1258,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
int vi;
@@ -1250,11 +1279,11 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1275,11 +1304,11 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
int qs;
@@ -1310,11 +1339,11 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1344,11 +1373,11 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
int vi;
@@ -1363,7 +1392,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
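
Background on the 600 -> 610 bumps in these guards: the "integer intrinsics" the vec_dot_*_q8_1 kernels depend on come down to __dp4a, the packed 4x8-bit dot-product-with-accumulate instruction, which is only available on devices of compute capability 6.1 and newer (sm_60 / GP100 lacks it), so gating on 600 was too permissive. A hedged sketch of the intrinsic with a manual fallback (dot4_example is illustrative, not code from the file):

    // Byte-wise dot product of two packed int8x4 values, plus an accumulator.
    static __device__ __forceinline__ int dot4_example(int a, int b, int acc) {
    #if __CUDA_ARCH__ >= 610
        return __dp4a(a, b, acc);                 // single hardware instruction on sm_61+
    #else
        const char4 va = *(const char4 *) &a;     // unpack the four signed bytes
        const char4 vb = *(const char4 *) &b;
        return acc + va.x*vb.x + va.y*vb.y + va.z*vb.z + va.w*vb.w;
    #endif
    }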
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
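
Launch geometry implied by these helpers: one block per row, and each block is exactly one warp (WARP_SIZE threads), which is why the kernels can reduce with warp shuffles and no shared memory; the assert requires ncols to be a multiple of WARP_SIZE. A worked example of the resulting configuration (numbers are illustrative, not from the diff):

    // Hypothetical call for 32 rows of 4096 floats on the default stream:
    //     norm_f32_cuda(d_x, d_dst, /*ncols=*/4096, /*nrows=*/32, /*stream=*/0);
    // => grid of 32 blocks, each a single warp of 32 threads;
    //    every thread strides over 4096/32 = 128 columns of its row.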
@@ -2237,16 +2272,21 @@ inline void ggml_cuda_op_add(
GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ // TODO: support broadcasting
+ GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
- const int64_t ne0 = src0->ne[0];
+ const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
+ const int64_t ne10 = src1->ne[0];
+
// compute
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else {
GGML_ASSERT(false);
}
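
The new assert makes the current limitation explicit: until broadcasting is implemented, src0 and src1 must hold the same total number of elements. A worked example of what passes and what trips the assert (shapes are illustrative, not from the diff):

    // src0->ne = {4096, 32, 1, 1}  -> ggml_nelements(src0) == 131072
    // src1->ne = {4096, 32, 1, 1}  -> 131072 == 131072, assert passes
    // src1->ne = {4096,  1, 1, 1}  ->   4096 != 131072, assert fires
    //                                   (the broadcasting case the TODO refers to)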
@@ -2265,10 +2305,9 @@ inline void ggml_cuda_op_mul(
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
-
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@@ -2277,7 +2316,7 @@ inline void ggml_cuda_op_mul(
float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
- float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+ float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
// compute
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
@@ -2310,6 +2349,28 @@ inline void ggml_cuda_op_silu(
(void) i1;
}
+inline void ggml_cuda_op_norm(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ // compute
+ norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
inline void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2356,7 +2417,7 @@ inline void ggml_cuda_op_mul_mat_vec(
src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0;
- const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
#endif
if (use_mul_mat_vec_q) {
@@ -2930,6 +2991,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
}
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3226,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
- cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+ CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf;
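
Wrapping the copy in CUDA_CHECK means a failing host-to-device cudaMemcpy now aborts with a diagnostic instead of being silently ignored. The macro itself is defined near the top of ggml-cuda.cu; an error-checking macro of this kind typically looks roughly like the following (a sketch under that assumption, not the file's exact definition):

    #include <stdio.h>
    #include <stdlib.h>
    #include <cuda_runtime.h>

    // Sketch of a CUDA error-checking macro in the style of CUDA_CHECK.
    #define CUDA_CHECK_SKETCH(call)                                           \
        do {                                                                  \
            cudaError_t err_ = (call);                                        \
            if (err_ != cudaSuccess) {                                        \
                fprintf(stderr, "CUDA error: %s at %s:%d\n",                  \
                        cudaGetErrorString(err_), __FILE__, __LINE__);        \
                exit(1);                                                      \
            }                                                                 \
        } while (0)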
@@ -3200,36 +3266,36 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
}
// recursively assign CUDA buffers until a compute tensor is found
- if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
- const ggml_op src0_op = tensor->src0->op;
+ if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
+ const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
- ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
}
}
- if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
- ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
+ if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
+ ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
}
tensor->backend = GGML_BACKEND_GPU;
struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
- const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+ const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW ||
force_inplace;
const size_t size = ggml_nbytes(tensor);
CUDA_CHECK(cudaSetDevice(g_main_device));
- if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
+ if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
size_t offset = 0;
if (tensor->op == GGML_OP_VIEW) {
- memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
+ memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
}
extra->data_device[g_main_device] = src0_ddc + offset;
} else if (tensor->op == GGML_OP_CPY) {
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
void * src1_ddv = src1_extra->data_device[g_main_device];
extra->data_device[g_main_device] = src1_ddv;
} else if (scratch) {
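
The mechanical src0/src1/opt[0] -> src[0]/src[1]/src[2] renames in this hunk track an upstream ggml change that folded the separate source-tensor fields of ggml_tensor into a single src[] array; a view's offset tensor, previously opt[0], therefore now lives at src[2]. Roughly (a sketch based only on the names in the diff; the array size is illustrative):

    // Before: dedicated source fields on ggml_tensor
    struct ggml_tensor_before {
        struct ggml_tensor * src0;     // -> src[0]
        struct ggml_tensor * src1;     // -> src[1]
        struct ggml_tensor * opt[4];   // opt[0] -> src[2], and so on
        // ...
    };
    // After: one array of sources
    struct ggml_tensor_after {
        struct ggml_tensor * src[6];
        // ...
    };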
@@ -3300,8 +3366,8 @@ void ggml_cuda_free_scratch() {
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
- || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
- || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
+ || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
+ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
switch (tensor->op) {
case GGML_OP_ADD:
@@ -3322,6 +3388,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_silu;
break;
+ case GGML_OP_NORM:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_norm;
+ break;
case GGML_OP_RMS_NORM:
if (!any_on_device) {
return false;
@@ -3329,7 +3401,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
func = ggml_cuda_rms_norm;
break;
case GGML_OP_MUL_MAT:
- if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
+ if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
return false;
}
func = ggml_cuda_mul_mat;
@@ -3383,6 +3455,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return true;
}
- func(tensor->src0, tensor->src1, tensor);
+ func(tensor->src[0], tensor->src[1], tensor);
return true;
}