author     Johannes Gäßler <johannesg@5d6.de>    2023-06-14 19:47:19 +0200
committer  GitHub <noreply@github.com>           2023-06-14 19:47:19 +0200
commit     254a7a7a5ff4c874ff8488f1f5cbdd7e9c89d682 (patch)
tree       65f35a2d189f3cf6f1f625b2acb343c2dd77790d /ggml-cuda.cu
parent     92549202659fc23ba9fec5e688227d0da9b06b40 (diff)
CUDA full GPU acceleration, KV cache in VRAM (#1827)
* Fixed CUDA RoPE
* ggml_cuda_mul_mat_vec_p021
* ggml_cuda_scale
* ggml_cuda_diag_mask_inf
* ggml_is_permuted
* ggml_cuda_cpy
* flatten rows for ggml_cuda_op
* Added a --low-vram option
* Fixed Windows performance
* Fixed LLAMA_CUDA_DMMV_Y > 1 for WizardLM
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu  797
1 file changed, 689 insertions, 108 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 3b9a5dd..0565571 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,5 +1,6 @@
#include <cstddef>
#include <cstdint>
+#include <limits>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
@@ -48,6 +49,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_cuda_op_t)(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i,
@@ -151,7 +153,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
@@ -655,10 +660,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+ if (row >= nrows) {
+ return;
+ }
+
const int tid = threadIdx.x;
const int iter_stride = 2*GGML_CUDA_DMMV_X;
@@ -703,8 +713,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
}
template <int n_thread, dot_kernel_k_t dot_kernel>
-static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
+static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+ if (row >= nrows) {
+ return;
+ }
+
const int tid = threadIdx.x;
const int iter_stride = QK_K;
@@ -737,6 +752,139 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y
}
}
+static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+ const half * x = (half *) vx;
+
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+ const int col_x = col_x0 + threadIdx.x;
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ // x is transposed and permuted
+ const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+ const float xi = __half2float(x[ix]);
+
+ const int row_y = col_x;
+
+
+ // y is not transposed but permuted
+ const int iy = channel*nrows_y + row_y;
+
+ tmp += xi * y[iy];
+ }
+
+ // dst is not transposed and not permuted
+ const int idst = channel*nrows_dst + row_dst;
+
+ // sum up partial sums and write back result
+ __syncthreads();
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ }
+
+ if (threadIdx.x == 0) {
+ dst[idst] = tmp;
+ }
+}
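
Both this kernel and mul_mat_vec_nc_f16_f32 below finish with the same warp-level butterfly reduction over __shfl_xor_sync. As a standalone reference, a minimal sketch of that pattern, assuming a 32-thread warp with all lanes active (not part of the patch):

    __device__ float warp_reduce_sum(float val) {
        // XOR butterfly: after log2(32) = 5 steps every lane holds the full sum
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            val += __shfl_xor_sync(0xffffffff, val, mask, 32);
        }
        return val;
    }
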
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+
+ const half * x = (half *) vx;
+
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ const int idst = channel*nrows_dst + row_dst;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+ const int col_x = col_x0 + threadIdx.x;
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+ const float xi = __half2float(x[ix]);
+
+ const int row_y = col_x;
+
+ const int iy = channel*nrows_y + row_y;
+
+ tmp += xi * y[iy];
+ }
+
+ // sum up partial sums and write back result
+ __syncthreads();
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ }
+
+ if (threadIdx.x == 0) {
+ dst[idst] = tmp;
+ }
+}
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+ const float * xi = (float *) cxi;
+ float * dsti = (float *) cdsti;
+
+ *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+ const float * xi = (float *) cxi;
+ half * dsti = (half *) cdsti;
+
+ *dsti = __float2half(*xi);
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+ // then combine those indices with the corresponding byte offsets to get the total offsets
+ const int i02 = i / (ne00*ne01);
+ const int i01 = (i - i02*ne01*ne00) / ne00;
+ const int i00 = i - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+
+ const int i12 = i / (ne10*ne11);
+ const int i11 = (i - i12*ne10*ne11) / ne10;
+ const int i10 = i - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
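
The copy kernel unflattens the linear index i twice, once against src0's shape (ne00, ne01) and once against src1's (ne10, ne11), then converts each result to a byte offset through that tensor's strides, which is what lets it copy between differently laid-out views. A small host-side mirror of the source-side arithmetic with made-up sizes (illustrative only, not part of the patch):

    #include <cstdio>

    int main() {
        // hypothetical contiguous f32 tensor: ne00 = 4 columns, ne01 = 3 rows
        const int ne00 = 4, ne01 = 3;
        const int nb00 = sizeof(float);   // 4 bytes per element
        const int nb01 = ne00 * nb00;     // 16 bytes per row
        const int nb02 = ne01 * nb01;     // 48 bytes per 2D slice

        const int i   = 10;                            // flattened element index
        const int i02 = i / (ne00*ne01);               // = 0
        const int i01 = (i - i02*ne01*ne00) / ne00;    // = 2
        const int i00 = i - i02*ne01*ne00 - i01*ne00;  // = 2
        printf("byte offset = %d\n", i00*nb00 + i01*nb01 + i02*nb02); // 40
        return 0;
    }
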
+
+// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
@@ -758,6 +906,72 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+ const int col = blockDim.x*blockIdx.x + threadIdx.x;
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int i = row*ncols + col;
+ // dst[i] = col > n_past + row ? -INFINITY : x[i];
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+}
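
Subtracting INT_MAX pushes masked positions to roughly -2.1e9 instead of a literal -INFINITY; after the exponential in the soft max that follows, both underflow to zero, which is what the comment means by equivalence. A tiny illustration in plain C++ (assumed float arithmetic, not part of the patch):

    #include <climits>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float x = 3.5f;                     // some unmasked logit
        const float masked = x - 1.0f * INT_MAX;  // what the kernel stores
        printf("exp(-INFINITY) = %g\n", std::exp(-INFINITY)); // prints 0
        printf("exp(masked)    = %g\n", std::exp(masked));    // prints 0
        return 0;
    }
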
+
+// the CUDA soft max implementation differs from the CPU implementation:
+// it uses floats instead of doubles and it does not subtract the row maximum
+// before exponentiating to normalize the values
+// in theory these changes risk rounding error and arithmetic overflow, but for LLaMA they seem to be fine
+static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockDim.y*blockIdx.y + threadIdx.y;
+ const int block_size = blockDim.x;
+ const int tid = threadIdx.x;
+
+ float tmp = 0.0;
+
+ for (int block_start = 0; block_start < ncols; block_start += block_size) {
+ const int col = block_start + tid;
+
+ if (col >= ncols) {
+ break;
+ }
+
+ const int i = row*ncols + col;
+ const float val = expf(x[i]);
+ tmp += val;
+ dst[i] = val;
+ }
+
+ // sum up partial sums
+ __syncthreads();
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ }
+
+ for (int block_start = 0; block_start < ncols; block_start += block_size) {
+ const int col = block_start + tid;
+
+ if (col >= ncols) {
+ break;
+ }
+
+ const int i = row*ncols + col;
+ dst[i] /= tmp;
+ }
+}
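
As the comment above notes, this kernel applies the exponential to the raw values without first subtracting the row maximum. For contrast, a generic sketch of the numerically stable formulation it alludes to (a conceptual reference, not the actual ggml CPU code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> soft_max_ref(const std::vector<float> & x) {
        float max_val = -INFINITY;
        for (float v : x) max_val = std::max(max_val, v);

        std::vector<float> y(x.size());
        float sum = 0.0f;
        for (size_t i = 0; i < x.size(); ++i) {
            y[i] = std::exp(x[i] - max_val);  // exponent <= 0, so it cannot overflow
            sum += y[i];
        }
        for (float & v : y) v /= sum;
        return y;
    }
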
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = scale * x[i];
+}
+
static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
@@ -831,73 +1045,92 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
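
The same change repeats for every launcher in this hunk: the old `nrows % GGML_CUDA_DMMV_Y == 0` assertion is replaced by a ceiling division for the grid size, paired with the new `row >= nrows` early return inside the kernels. With GGML_CUDA_DMMV_Y = 2 and, say, nrows = 7 (illustrative numbers):

    // block_num_y = (7 + 2 - 1) / 2 = 4 blocks -> row slots 0..7 are scheduled
    // row 7 does not exist, so that thread hits the bounds check and returns
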
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 2, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 2, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 2, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
- const dim3 block_dims(32, 2, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<nrows/2, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ const int ny = 2;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -907,10 +1140,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+ const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16>
- <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
@@ -942,6 +1176,47 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}
+static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
+ const dim3 block_nums(1, nrows_x, nchannels_x);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+ const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+
+ const dim3 block_nums(1, nrows_x, nchannels_x);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+ const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+ scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0);
const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -950,6 +1225,19 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
}
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+ const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, nrows_x, 1);
+ diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(1, nrows_x, 1);
+ soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+}
+
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@@ -1120,10 +1408,25 @@ void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
}
-static cudaError_t ggml_cuda_h2d_tensor_2d(
+static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
- char * dst_char = (char *) dst;
+ cudaMemcpyKind kind;
+ char * src_ptr;
+ if (src->backend == GGML_BACKEND_CPU) {
+ kind = cudaMemcpyHostToDevice;
+ src_ptr = (char *) src->data;
+ } else if (src->backend == GGML_BACKEND_GPU) {
+ kind = cudaMemcpyDeviceToDevice;
+ struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ src_ptr = (char *) extra->data_device[id];
+ } else {
+ GGML_ASSERT(false);
+ }
+ char * dst_ptr = (char *) dst;
+
const int64_t ne0 = src->ne[0];
const int64_t nb0 = src->nb[0];
const int64_t nb1 = src->nb[1];
@@ -1134,17 +1437,17 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
- const void * x = (const void *) ((const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3);
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
- return cudaMemcpyAsync(dst_char, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
+ return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
} else if (nb0 == ts) {
- return cudaMemcpy2DAsync(dst_char, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyHostToDevice, stream);
+ return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
} else {
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
- void * rd = (void *) (dst_char + i1*ts*ne0/bs);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
- cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
if (r != cudaSuccess) return r;
}
return cudaSuccess;
@@ -1380,8 +1683,81 @@ inline void ggml_cuda_op_rope(
(void) i1;
}
+inline void ggml_cuda_op_diag_mask_inf(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ const int n_past = ((int32_t *) src1->data)[0];
+
+ // compute
+ diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
+ CUDA_CHECK(cudaGetLastError());
+
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
+inline void ggml_cuda_op_soft_max(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ // compute
+ soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ CUDA_CHECK(cudaGetLastError());
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
+inline void ggml_cuda_op_scale(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const float scale = ((float *) src1->data)[0];
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ // compute
+ scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main);
+ CUDA_CHECK(cudaGetLastError());
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
- ggml_cuda_op_t op, bool src0_needs_f32) {
+ ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
@@ -1404,21 +1780,27 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
// strides for iteration over dims 3 and 2
- const int64_t src0_stride = ne00 * ne01;
- const int64_t src1_stride = ne10 * ne11;
- const int64_t dst_stride = ne0 * ne1;
- const int64_t num_iters = ne02 * ne03;
+ const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
+ const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+ const int64_t src0_stride = ne00 * ne01 * stride_mod;
+ const int64_t src1_stride = ne10 * ne11 * stride_mod;
+ const int64_t dst_stride = ne0 * ne1 * stride_mod;
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
- struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
+ const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
+ const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+ const bool src1_stays_on_host = use_src1 && (
+ dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
+
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -1427,13 +1809,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
- float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+ float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
// asq = actual size quantized, asf = actual size float
size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
- size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+ size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
for (int id = 0; id < g_device_count; ++id) {
if (!split && id != g_main_device) {
@@ -1446,9 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
int64_t row_low, row_high;
if (split) {
row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
- row_low -= row_low % GGML_CUDA_DMMV_Y;
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
- row_high -= row_high % GGML_CUDA_DMMV_Y;
} else {
row_low = 0;
row_high = nrows0;
@@ -1461,7 +1841,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
cudaSetDevice(id);
- if (src0_on_device) {
+ if (src0_on_device && src0_is_contiguous) {
if (src0_is_f32) {
src0_ddf[id] = (float *) src0_extra->data_device[id];
} else {
@@ -1479,8 +1859,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
}
- if (use_src1) {
- if (src1_on_device) {
+ if (use_src1 && !src1_stays_on_host) {
+ if (src1_on_device && src1_is_contiguous) {
src1_ddf[id] = (float *) src1_extra->data_device[id];
} else {
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
@@ -1493,26 +1873,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
}
- for (int64_t i03 = 0; i03 < ne03; i03++) {
+ const int64_t i03_max = flatten_rows ? 1 : ne03;
+ const int64_t i02_max = flatten_rows ? 1 : ne02;
+ const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+
+ for (int64_t i03 = 0; i03 < i03_max; i03++) {
const int64_t i13 = i03 % ne13;
- for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i02 = 0; i02 < i02_max; i02++) {
const int64_t i12 = i02 % ne12;
const int64_t i0 = i03*ne02 + i02;
- const int64_t i0_offset_low = row_low/ne01;
- const int64_t i0_offset_high = row_high/ne01;
+
+ // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
+ const int64_t i0_offset_low = row_low/rows_per_iter;
+ const int64_t i0_offset_high = row_high/rows_per_iter;
int64_t i01_low = 0;
- int64_t i01_high = ne01;
+ int64_t i01_high = rows_per_iter;
if (split) {
if (i0 < i0_offset_low || i0 > i0_offset_high) {
continue;
}
if (i0 == i0_offset_low) {
- i01_low = row_low % ne01;
+ i01_low = row_low % rows_per_iter;
}
if (i0 == i0_offset_high) {
- i01_high = row_high % ne01;
+ i01_high = row_high % rows_per_iter;
}
}
@@ -1521,7 +1907,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
// Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
// The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
GGML_ASSERT(i01_low == 0 || g_device_count > 1);
- GGML_ASSERT(i01_high == ne01 || g_device_count > 1);
+ GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1);
const int64_t i01_diff = i01_high - i01_low;
if (i01_diff == 0) {
@@ -1529,24 +1915,23 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
const int64_t i11 = i13*ne12 + i12;
- cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+ cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
// for split tensors the data begins at i0 == i0_offset_low
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
- float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+ float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
// for split tensors the data pointer needs to be rounded down
// to the bin edge for i03, i02 bins beyond the first
if (i0 - i0_offset_low > 0) {
+ GGML_ASSERT(!flatten_rows);
src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
src0_ddf_i -= (row_low % ne01)*ne00;
- }
- if (i0 - i0_offset_low > 0) {
- dst_ddf_i -= (row_low % ne0)*ne1;
+ dst_ddf_i -= (row_low % ne0)*ne1;
}
// the main device memory buffer can be on VRAM scratch, with space for all partial results
@@ -1556,30 +1941,37 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
// copy src0, src1 to device if necessary
- if (use_src1) {
+ if (use_src1 && !src1_stays_on_host) {
if (src1->backend == GGML_BACKEND_CPU) {
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
- } else if (src1->backend == GGML_BACKEND_GPU) {
+ GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
+ int64_t nrows1 = flatten_rows ? nrows0 : ne11;
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+ } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (id != g_main_device) {
+ GGML_ASSERT(!flatten_rows);
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
src1_ddf_i_source += i11*src1_stride;
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
}
+ } else if (src1_on_device && !src1_is_contiguous) {
+ GGML_ASSERT(!split);
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main));
} else {
GGML_ASSERT(false);
}
}
CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
- if (!src0_on_device) {
+
+ if (!src0_on_device || !src0_is_contiguous) {
if (src0_is_f32) {
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
} else {
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
}
}
- // convert src0 to f32 if it's necessary for the ggml_cuda_op
+ // convert src0 to f32 if it is necessary for the ggml_cuda_op
if (src0_needs_f32 && !src0_is_f32) {
to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
CUDA_CHECK(cudaGetLastError());
@@ -1644,39 +2036,30 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
}
void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
}
void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
}
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
}
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
- GGML_ASSERT(src0->backend != GGML_BACKEND_GPU);
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
- // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) {
- // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n",
- // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
- // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
- // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
- // return false;
- // }
-
// TODO: find the optimal values for these
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
@@ -1688,23 +2071,158 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
return false;
}
+void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ void * src0_ddq = src0_extra->data_device[g_main_device];
+
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+
+ CUDA_CHECK(cudaDeviceSynchronize());
+}
+
+void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+ GGML_ASSERT(!ggml_is_permuted(src0));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ void * src0_ddq = src0_extra->data_device[g_main_device];
+
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+ struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+ float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+ const int row_stride_x = nb01 / sizeof(half);
+ const int channel_stride_x = nb02 / sizeof(half);
+
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+
+ CUDA_CHECK(cudaDeviceSynchronize());
+}
+
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- if (src0->type == GGML_TYPE_F32) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+ bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+ src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+
+ if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
+ } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+ ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+ }else if (src0->type == GGML_TYPE_F32) {
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
- if (src1->ne[1] == 1) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
} else {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
}
} else {
GGML_ASSERT(false);
}
}
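
Taken together, the new dispatch in ggml_cuda_mul_mat picks a path roughly as follows (a paraphrase of the branches above, not additional logic):

    // all on GPU, src0 and src1 permuted, src1->ne[1] == 1         -> ggml_cuda_mul_mat_vec_p021
    // all on GPU, src0 non-contiguous, src1 contiguous, ne[1] == 1 -> ggml_cuda_mul_mat_vec_nc
    // src0 is F32                                                  -> cuBLAS (ggml_cuda_op_mul_mat_cublas)
    // src0 quantized/F16, src1->ne[1] == 1, dims divisible by
    //   GGML_CUDA_DMMV_X and GGML_CUDA_DMMV_Y                      -> ggml_cuda_op_dequantize_mul_mat_vec
    // src0 quantized/F16 otherwise                                 -> cuBLAS (ggml_cuda_op_mul_mat_cublas)
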
+void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true);
+}
+
+void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const int64_t ne = ggml_nelements(src0);
+ GGML_ASSERT(ne == ggml_nelements(src1));
+
+ GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
+ GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
+
+ GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+ GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ GGML_ASSERT(src0->ne[3] == 1);
+
+ const int64_t nb00 = src0->nb[0];
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ GGML_ASSERT(src1->ne[3] == 1);
+
+ const int64_t nb10 = src1->nb[0];
+ const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2];
+
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+
+ const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+ char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
+
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+ ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+ ne10, ne11, nb10, nb11, nb12, cudaStream_main);
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ CUDA_CHECK(cudaDeviceSynchronize());
+
+ (void) dst;
+}
+
+void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
+}
+
+void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true);
+}
+
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
}
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -1718,10 +2236,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
const size_t nb1 = tensor->nb[1];
ggml_backend backend = tensor->backend;
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
+ memset(extra, 0, sizeof(*extra));
for (int id = 0; id < g_device_count; ++id) {
- extra->data_device[id] = nullptr;
-
if (backend == GGML_BACKEND_GPU && id != g_main_device) {
continue;
}
@@ -1734,10 +2251,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
row_high = nrows;
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
- row_low -= row_low % GGML_CUDA_DMMV_Y;
row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
- row_high -= row_high % GGML_CUDA_DMMV_Y;
- GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
} else {
GGML_ASSERT(false);
}
@@ -1781,45 +2295,76 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
delete extra;
}
-void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
- if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) {
- ggml_cuda_assign_buffers(tensor);
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+ if (scratch && g_scratch_size == 0) {
+ return;
}
- const size_t size = ggml_nbytes(tensor);
- GGML_ASSERT(size <= g_scratch_size);
- if (g_scratch_offset + size > g_scratch_size) {
- g_scratch_offset = 0;
+ // recursively assign CUDA buffers until a compute tensor is found
+ if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
+ const ggml_op src0_op = tensor->src0->op;
+ if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+ ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+ }
+ }
+ if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
+ ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
}
tensor->backend = GGML_BACKEND_GPU;
struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
- bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data;
+ const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+ tensor->op == GGML_OP_VIEW;
+ const size_t size = ggml_nbytes(tensor);
CUDA_CHECK(cudaSetDevice(g_main_device));
if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
- extra->data_device[g_main_device] = src0_extra->data_device;
- GGML_ASSERT(false);
- } else {
+ char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+ size_t offset = 0;
+ if (tensor->op == GGML_OP_VIEW) {
+ memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
+ }
+ extra->data_device[g_main_device] = src0_ddc + offset;
+ } else if (tensor->op == GGML_OP_CPY) {
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+ void * src1_ddv = src1_extra->data_device[g_main_device];
+ extra->data_device[g_main_device] = src1_ddv;
+ } else if (scratch) {
+ GGML_ASSERT(size <= g_scratch_size);
+ if (g_scratch_offset + size > g_scratch_size) {
+ g_scratch_offset = 0;
+ }
+
char * data = (char *) g_scratch_buffer;
if (data == nullptr) {
CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
g_scratch_buffer = data;
}
extra->data_device[g_main_device] = data + g_scratch_offset;
- }
- // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]);
- g_scratch_offset += size;
- // fprintf(stderr, "%s: scratch %d, %p - %p\n",
- // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size);
+ g_scratch_offset += size;
+
+ GGML_ASSERT(g_scratch_offset <= g_scratch_size);
+ } else { // allocate new buffers outside of scratch
+ void * data;
+ CUDA_CHECK(cudaMalloc(&data, size));
+ CUDA_CHECK(cudaMemset(data, 0, size));
+ extra->data_device[g_main_device] = data;
+ }
- GGML_ASSERT(g_scratch_offset <= g_scratch_size);
tensor->extra = extra;
}
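
The scratch branch above hands out consecutive sub-ranges of a single cudaMalloc'd buffer and simply wraps g_scratch_offset back to zero whenever the next tensor would not fit. A short trace with made-up sizes (illustrative only):

    // g_scratch_size = 1024 bytes
    //   tensor A, size 400: offset   0 -> data + 0,   g_scratch_offset becomes 400
    //   tensor B, size 400: offset 400 -> data + 400, g_scratch_offset becomes 800
    //   tensor C, size 400: 800 + 400 > 1024, offset resets -> data + 0 (A's range is reused)
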
+void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
+ ggml_cuda_assign_buffers_impl(tensor, true);
+}
+
+void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
+ ggml_cuda_assign_buffers_impl(tensor, false);
+}
+
void ggml_cuda_set_main_device(int main_device) {
if (main_device > g_device_count) {
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
@@ -1838,6 +2383,15 @@ void ggml_cuda_set_scratch_size(size_t scratch_size) {
g_scratch_size = scratch_size;
}
+void ggml_cuda_free_scratch() {
+ if (g_scratch_buffer == nullptr) {
+ return;
+ }
+
+ CUDA_CHECK(cudaFree(g_scratch_buffer));
+ g_scratch_buffer = nullptr;
+}
+
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -1875,12 +2429,39 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_mul_mat;
break;
+ case GGML_OP_SCALE:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_scale;
+ break;
+ case GGML_OP_CPY:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_cpy;
+ break;
case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
if (!any_on_device) {
return false;
}
func = ggml_cuda_nop;
break;
+ case GGML_OP_DIAG_MASK_INF:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_diag_mask_inf;
+ break;
+ case GGML_OP_SOFT_MAX:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_soft_max;
+ break;
case GGML_OP_ROPE:
if (!any_on_device) {
return false;