-rw-r--r--   examples/common.cpp    25
-rw-r--r--   examples/common.h      11
-rw-r--r--   ggml-cuda.cu          287
-rw-r--r--   ggml-cuda.h             2
-rw-r--r--   ggml.c                  1
-rw-r--r--   ggml.h                  8
-rw-r--r--   llama.cpp              37
-rw-r--r--   llama.h                 7
8 files changed, 336 insertions, 42 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 80e35d2..86c1eef 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "--mlock") {
params.use_mlock = true;
+ } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_gpu_layers = std::stoi(argv[i]);
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--mtest") {
@@ -421,6 +427,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
if (llama_mmap_supported()) {
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
+ fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
+ fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -463,14 +471,15 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
- lparams.n_ctx = params.n_ctx;
- lparams.n_parts = params.n_parts;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
- lparams.logits_all = params.perplexity;
- lparams.embedding = params.embedding;
+ lparams.n_ctx = params.n_ctx;
+ lparams.n_parts = params.n_parts;
+ lparams.n_gpu_layers = params.n_gpu_layers;
+ lparams.seed = params.seed;
+ lparams.f16_kv = params.memory_f16;
+ lparams.use_mmap = params.use_mmap;
+ lparams.use_mlock = params.use_mlock;
+ lparams.logits_all = params.perplexity;
+ lparams.embedding = params.embedding;
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
diff --git a/examples/common.h b/examples/common.h
index 499671b..717838f 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -21,13 +21,14 @@
int32_t get_num_physical_cores();
struct gpt_params {
- int32_t seed = -1; // RNG seed
+ int32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_predict = -1; // new tokens to predict
- int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
- int32_t n_ctx = 512; // context size
- int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_gpu_layers = 0; // number of layers to store in VRAM
// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
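
Together, the two files above add the -ngl / --gpu-layers / --n-gpu-layers flag, store it in gpt_params.n_gpu_layers, and forward it to llama_context_params inside llama_init_from_gpt_params. A minimal sketch of the same flow driven directly from C++ (the model path is a placeholder, not part of this patch):

    // Sketch: request GPU offloading through the common helper, equivalent to
    // passing "-ngl 16" on the command line of the examples.
    #include "common.h"
    #include "llama.h"

    int main() {
        gpt_params params;
        params.model        = "models/7B/ggml-model-q4_0.bin";  // placeholder path
        params.n_gpu_layers = 16;  // keep the first 16 layers' weights in VRAM

        llama_context * ctx = llama_init_from_gpt_params(params);
        if (ctx == NULL) {
            return 1;
        }
        llama_free(ctx);
        return 0;
    }
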
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8a3beb0..b6a7754 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -32,9 +32,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
+// QK = number of values after dequantization
+// QR = QK / number of values before dequantization
#define QK4_0 32
+#define QR4_0 2
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
@@ -42,6 +48,7 @@ typedef struct {
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
+#define QR4_1 2
typedef struct {
float d; // delta
float m; // min
@@ -50,6 +57,7 @@ typedef struct {
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
+#define QR5_0 2
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
@@ -58,6 +66,7 @@ typedef struct {
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
+#define QR5_1 2
typedef struct {
half d; // delta
half m; // min
@@ -67,12 +76,100 @@ typedef struct {
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
+#define QR8_0 1
typedef struct {
float d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+#define CUDA_DMMV_BLOCK_SIZE 32
+
+static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const float d = x[ib].d;
+
+ const uint8_t vui = x[ib].qs[iqs];
+
+ const int8_t vi0 = vui & 0xF;
+ const int8_t vi1 = vui >> 4;
+
+ v0 = (vi0 - 8)*d;
+ v1 = (vi1 - 8)*d;
+}
+
+static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const float d = x[ib].d;
+ const float m = x[ib].m;
+
+ const uint8_t vui = x[ib].qs[iqs];
+
+ const int8_t vi0 = vui & 0xF;
+ const int8_t vi1 = vui >> 4;
+
+ v0 = vi0*d + m;
+ v1 = vi1*d + m;
+}
+
+static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const float d = x[ib].d;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
+ const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
+
+ v0 = x0*d;
+ v1 = x1*d;
+}
+
+static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const float d = x[ib].d;
+ const float m = x[ib].m;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+ v0 = x0*d + m;
+ v1 = x1*d + m;
+}
+
+static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const float d = x[ib].d;
+
+ const int8_t vi0 = x[ib].qs[iqs + 0];
+ const int8_t vi1 = x[ib].qs[iqs + 1];
+
+ v0 = vi0*d;
+ v1 = vi1*d;
+}
+
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+ const half * x = (const half *) vx;
+
+ v0 = __half2float(x[ib + 0]);
+ v1 = __half2float(x[ib + 1]);
+}
+
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
static const int qk = QK4_0;
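
The QK*/QR* pairs above encode the block geometry the new kernels rely on: QK is how many floats one block dequantizes to, and QR is how many dequantized values each stored quant position yields. For q4_0 a block holds QK4_0 = 32 values packed as 16 bytes of nibbles, so QR4_0 = 2; q8_0 stores one int8 per value, so QR8_0 = 1. Each dequantize_* helper therefore returns a pair (v0, v1); for the 4- and 5-bit formats the low nibble at quant index iqs belongs at position iqs inside the block and the high nibble at position iqs + QK/2, which is exactly the y_offset = qk/2 used by the kernel added below. A host-side reference for the q4_0 case, as a sketch only (not part of the patch), assuming the block_q4_0 struct above:

    // Reference (CPU) dequantization of one q4_0 block; out must hold QK4_0 floats.
    // Matches the value placement used by dequantize_q4_0 + dequantize_mul_mat_vec.
    static void dequantize_block_q4_0_ref(const block_q4_0 * b, float * out) {
        const float d = b->d;
        for (int iqs = 0; iqs < QK4_0/2; ++iqs) {
            const uint8_t vui = b->qs[iqs];
            out[iqs]           = ((vui & 0xF) - 8) * d;  // low nibble -> first half of block
            out[iqs + QK4_0/2] = ((vui >>  4) - 8) * d;  // high nibble -> second half
        }
    }
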
@@ -173,6 +270,44 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}
+template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+ const int row = blockIdx.x;
+ const int tid = threadIdx.x;
+
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ __shared__ float tmp[block_size]; // separate sum for each thread
+ tmp[tid] = 0;
+
+ for (int i = 0; i < ncols/block_size; i += 2) {
+ const int col = i*block_size + 2*tid;
+ const int ib = (row*ncols + col)/qk; // block index
+ const int iqs = (col%qk)/qr; // quant index
+ const int iybs = col - col%qk; // y block start index
+
+ // dequantize
+ float v0, v1;
+ dequantize_kernel(vx, ib, iqs, v0, v1);
+
+ // matrix multiplication
+ tmp[tid] += v0 * y[iybs + iqs + 0];
+ tmp[tid] += v1 * y[iybs + iqs + y_offset];
+ }
+
+ // sum up partial sums and write back result
+ __syncthreads();
+ for (int s=block_size/2; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
+}
+
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
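
dequantize_mul_mat_vec launches one CUDA block per output row with CUDA_DMMV_BLOCK_SIZE = 32 threads. Each thread handles a disjoint subset of the row's columns, dequantizing two values per loop iteration and accumulating a partial dot product against y in shared memory; a shared-memory tree reduction then combines the 32 partial sums and thread 0 writes dst[row]. For example, with ncols = 4096 the loop body runs 64 times per thread (i = 0, 2, ..., 126), producing 128 products per thread and 32 * 128 = 4096 in total, i.e. the whole row. The result per row is an ordinary dequantized dot product; a host-side sketch of the equivalent computation (assuming x already dequantized to fp32):

    // Host reference for what the kernel computes: dst[row] = dot(x_row, y).
    static void mul_mat_vec_ref(const float * x, const float * y, float * dst,
                                int ncols, int nrows) {
        for (int row = 0; row < nrows; ++row) {
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) {
                sum += x[row*ncols + col] * y[col];
            }
            dst[row] = sum;
        }
    }
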
@@ -198,6 +333,36 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
const half * x = (const half *) vx;
@@ -211,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
}
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+ dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
+ <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+}
+
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
@@ -230,8 +401,27 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}
+static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_mul_mat_vec_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_mul_mat_vec_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_mul_mat_vec_q5_0_cuda;
+ case GGML_TYPE_Q5_1:
+ return dequantize_mul_mat_vec_q5_1_cuda;
+ case GGML_TYPE_Q8_0:
+ return dequantize_mul_mat_vec_q8_0_cuda;
+ case GGML_TYPE_F16:
+ return convert_mul_mat_vec_f16_cuda;
+ default:
+ return nullptr;
+ }
+}
+
// buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
struct scoped_spin_lock {
std::atomic_flag& lock;
@@ -528,6 +718,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
const ggml_type type = src0->type;
+ const bool mul_mat_vec = ne11 == 1;
const float alpha = 1.0f;
const float beta = 0.0f;
@@ -538,12 +729,16 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
size_t x_size, y_size, d_size, q_size;
- float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+ float * d_X = nullptr;
+ if (!mul_mat_vec) {
+ d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+ }
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+ dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
GGML_ASSERT(to_fp32_cuda != nullptr);
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -553,31 +748,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
- float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
char * c_Q = d_Q + i * q_sz;
- // copy src0 and convert to fp32 on device
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
- to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
- CUDA_CHECK(cudaGetLastError());
- CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+ // copy src0 to device if necessary
+ if (src0->backend == GGML_BACKEND_CPU) {
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+ } else if (src0->backend == GGML_BACKEND_CUDA) {
+ c_Q = ((char *) src0->data) + i * q_sz;
+ } else {
+ GGML_ASSERT(false);
+ }
+ if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
+ CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
- // copy src1 to device
- CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+ // copy src1 to device
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
- // wait for conversion
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+ // wait for data
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
- // compute
- CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
- CUBLAS_CHECK(
- cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- &alpha, c_X, ne00,
- c_Y, ne10,
- &beta, c_D, ne01));
+ // compute
+ dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+ CUDA_CHECK(cudaGetLastError());
+
+ } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
+ float * c_X = d_X + i * x_ne;
+
+ // convert src0 to fp32 on device
+ to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+ CUDA_CHECK(cudaGetLastError());
+ CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+ // copy src1 to device
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+ // wait for conversion
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+ // compute
+ CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+ CUBLAS_CHECK(
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ &alpha, c_X, ne00,
+ c_Y, ne10,
+ &beta, c_D, ne01));
+ }
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -586,7 +804,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
}
CUDA_CHECK(cudaDeviceSynchronize());
- ggml_cuda_pool_free(d_X, x_size);
+ if (!mul_mat_vec) {
+ ggml_cuda_pool_free(d_X, x_size);
+ }
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
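
The host-side changes above split ggml_cuda_mul_mat_q_f32 into two paths. When ne11 == 1 (src1 is a single column, the common case during token-by-token generation) the quantized weights are fed straight to the fused dmmv kernel, so neither the fp32 staging buffer d_X nor cublasSgemm is needed; otherwise the previous dequantize-then-SGEMM path is kept. Independently, if src0 was already placed in VRAM (backend == GGML_BACKEND_CUDA) the host-to-device copy of the weights is skipped. A condensed sketch of that dispatch, with stream/event handling and error checks omitted:

    // Condensed view of the per-(i02, i03) loop body above; not a drop-in replacement.
    if (src0->backend == GGML_BACKEND_CUDA) {
        c_Q = ((char *) src0->data) + i * q_sz;        // weights already resident in VRAM
    } else {
        ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2);  // copy quantized weights
    }
    if (ne11 == 1) {
        dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);   // fused dequantize + dot product
    } else {
        to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);     // dequantize to fp32 ...
        // ... then cublasSgemm(c_X, c_Y) into c_D as before
    }
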
@@ -602,8 +822,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+ ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
return true;
}
@@ -655,3 +874,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
return 0;
}
}
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+ const int64_t ne0 = tensor->ne[0];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne3 = tensor->ne[3];
+
+ const ggml_type type = tensor->type;
+ const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+ size_t q_size;
+ char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+ cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+ // copy tensor to device
+ CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+ CUDA_CHECK(cudaDeviceSynchronize());
+
+ tensor->data = d_Q;
+ tensor->backend = GGML_BACKEND_CUDA;
+}
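
ggml_cuda_transform_tensor is the piece that makes weight offloading persistent: it copies the tensor's data into a device buffer from the CUDA pool once, repoints tensor->data at that device allocation, and tags the tensor with GGML_BACKEND_CUDA. From then on the host must not dereference tensor->data, and ggml_cuda_can_mul_mat accepts the tensor regardless of its dimensions (see the relaxed size check above). Minimal usage sketch, assuming t is a quantized weight tensor that should live in VRAM:

    // After this call, t->data points to device memory and t->backend is
    // GGML_BACKEND_CUDA; the tensor no longer references its host copy.
    ggml_cuda_transform_tensor(t);
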
diff --git a/ggml-cuda.h b/ggml-cuda.h
index f7d6a8b..4e2c242 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
+void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
#ifdef __cplusplus
}
#endif
diff --git a/ggml.c b/ggml.c
index 675eb0d..0574638 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3882,6 +3882,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
*result = (struct ggml_tensor) {
/*.type =*/ type,
+ /*.backend =*/ GGML_BACKEND_CPU,
/*.n_dims =*/ n_dims,
/*.ne =*/ { 1, 1, 1, 1 },
/*.nb =*/ { 0, 0, 0, 0 },
diff --git a/ggml.h b/ggml.h
index 2745fb3..967ef72 100644
--- a/ggml.h
+++ b/ggml.h
@@ -243,6 +243,11 @@ extern "C" {
GGML_TYPE_COUNT,
};
+ enum ggml_backend {
+ GGML_BACKEND_CPU = 0,
+ GGML_BACKEND_CUDA = 1,
+ };
+
// model file types
enum ggml_ftype {
GGML_FTYPE_UNKNOWN = -1,
@@ -333,6 +338,7 @@ extern "C" {
// n-dimensional tensor
struct ggml_tensor {
enum ggml_type type;
+ enum ggml_backend backend;
int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -363,7 +369,7 @@ extern "C" {
char name[32];
- char padding[8]; // TODO: remove and add padding to name?
+ char padding[9]; // TODO: remove and add padding to name?
};
// computation graph
diff --git a/llama.cpp b/llama.cpp
index 08c7352..73b932a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9,6 +9,9 @@
#include "llama.h"
#include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
#include <array>
#include <ctime>
@@ -810,6 +813,7 @@ struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.n_ctx =*/ 512,
/*.n_parts =*/ -1,
+ /*.n_gpu_layers =*/ 0,
/*.seed =*/ -1,
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
@@ -876,6 +880,7 @@ static void llama_model_load_internal(
const std::string & fname,
llama_context & lctx,
int n_ctx,
+ int n_gpu_layers,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -1022,6 +1027,33 @@ static void llama_model_load_internal(
ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+ {
+ const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+ fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+
+ size_t vram_total = 0;
+
+ for (int i = 0; i < n_gpu; ++i) {
+ const auto & layer = model.layers[i];
+
+ ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
+ ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
+ ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
+ ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
+ ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
+ ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
+ ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
+ }
+ if (n_gpu_layers > (int) hparams.n_layer) {
+ fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+ ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+ }
+
+ fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+ }
+#endif
// loading time will be recalculate after the first eval, so
// we take page faults deferred by mmap() into consideration
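
For each offloaded layer the loop above moves the seven per-layer weight matrices (attention wq/wk/wv/wo and feed-forward w1/w2/w3) to VRAM and accumulates their sizes; passing -ngl with a value greater than n_layer additionally offloads the output projection. Because ggml_nbytes() is known before any copy happens, the same walk could estimate the VRAM requirement up front; a hedged sketch (the helper name is invented, and it assumes the model/layer fields used by the loop above):

    // Hypothetical helper: estimate how much VRAM a given --gpu-layers value
    // would need, without copying anything. Mirrors the offloading loop above.
    static size_t llama_estimate_vram(const llama_model & model, int n_gpu_layers) {
        const int n_gpu = std::min(n_gpu_layers, (int) model.hparams.n_layer);
        size_t total = 0;
        for (int i = 0; i < n_gpu; ++i) {
            const auto & layer = model.layers[i];
            total += ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + ggml_nbytes(layer.wv)
                   + ggml_nbytes(layer.wo) + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2)
                   + ggml_nbytes(layer.w3);
        }
        if (n_gpu_layers > (int) model.hparams.n_layer) {
            total += ggml_nbytes(model.output);
        }
        return total;
    }
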
@@ -1032,6 +1064,7 @@ static bool llama_model_load(
const std::string & fname,
llama_context & lctx,
int n_ctx,
+ int n_gpu_layers,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -1039,7 +1072,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+ llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::string & err) {
@@ -2111,7 +2144,7 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
+ if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
params.use_mmap, params.use_mlock, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
diff --git a/llama.h b/llama.h
index ca05645..21cba8c 100644
--- a/llama.h
+++ b/llama.h
@@ -54,9 +54,10 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
- int n_ctx; // text context
- int n_parts; // -1 for default
- int seed; // RNG seed, -1 for random
+ int n_ctx; // text context
+ int n_parts; // -1 for default
+ int n_gpu_layers; // number of layers to store in VRAM
+ int seed; // RNG seed, -1 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
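
On the C API side the only addition is the n_gpu_layers field; a caller using llama.h directly rather than the common.cpp helpers sets it on the default parameters before llama_init_from_file. Minimal sketch (the model path is a placeholder):

    // Sketch of the low-level API; values > n_layer also offload the output matrix.
    llama_context_params lparams = llama_context_default_params();
    lparams.n_gpu_layers = 32;   // number of layers to keep in VRAM
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
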