Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp | 175
1 file changed, 143 insertions, 32 deletions
diff --git a/llama.cpp b/llama.cpp
index 73f6860..b992321 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -59,6 +59,12 @@ static const size_t MB = 1024*1024;
// TODO: dynamically determine these sizes
// needs modifications in ggml
+typedef void (*offload_func_t)(struct ggml_tensor * tensor);
+
+void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
+ (void) tensor;
+}
+
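
The function-pointer indirection is what keeps the graph-building code backend-agnostic: every tensor it creates is handed to an offload_func_t that is either the llama_nop above or a CUDA hook that marks the tensor for GPU placement. A minimal standalone sketch of the same pattern, using a stand-in Tensor type instead of ggml_tensor:

    // Sketch of the offload-callback pattern (stand-in types, not ggml).
    #include <cstdio>

    struct Tensor { int backend; };    // 0 = CPU, 1 = GPU

    typedef void (*offload_func_t)(Tensor * tensor);

    static void nop_offload(Tensor * tensor) { (void) tensor; }       // analogue of llama_nop
    static void gpu_offload(Tensor * tensor) { tensor->backend = 1; } // analogue of ggml_cuda_assign_buffers

    int main() {
        const int n_layer      = 4;
        const int n_gpu_layers = 2;
        const int i_gpu_start  = n_layer - n_gpu_layers; // only the last n_gpu_layers layers move to the GPU

        for (int il = 0; il < n_layer; ++il) {
            offload_func_t offload_func = nop_offload;
            if (il >= i_gpu_start) {
                offload_func = gpu_offload;
            }
            Tensor cur = {0};
            offload_func(&cur);
            printf("layer %d -> backend %d\n", il, cur.backend);
        }
        return 0;
    }
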
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
{
static std::map<e_model, size_t> k_sizes = {
@@ -173,6 +179,7 @@ struct llama_model {
struct ggml_tensor * output;
std::vector<llama_layer> layers;
+ int n_gpu_layers;
// context
struct ggml_context * ctx = NULL;
@@ -198,6 +205,12 @@ struct llama_model {
if (ctx) {
ggml_free(ctx);
}
+
+#ifdef GGML_USE_CUBLAS
+ for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+ ggml_cuda_free_data(tensors_by_name[i].second);
+ }
+#endif // GGML_USE_CUBLAS
}
};
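
With weights resident on the device, the model destructor gains a second cleanup pass: once ggml_free(ctx) has released the host-side context, it walks tensors_by_name and calls ggml_cuda_free_data on each entry so the VRAM copies of the offloaded tensors are released as well. A standalone sketch of that teardown, with plain heap allocations standing in for device buffers:

    #include <cstdlib>
    #include <string>
    #include <utility>
    #include <vector>

    struct FakeTensor { void * device_data = nullptr; };   // stand-in for a ggml_tensor with a VRAM copy

    struct FakeModel {
        std::vector<std::pair<std::string, FakeTensor *>> tensors_by_name;

        ~FakeModel() {
            // the host context would be freed first (ggml_free analogue), then the device copies:
            for (size_t i = 0; i < tensors_by_name.size(); ++i) {
                std::free(tensors_by_name[i].second->device_data);   // ggml_cuda_free_data analogue
                tensors_by_name[i].second->device_data = nullptr;
            }
        }
    };

    int main() {
        FakeTensor t;
        t.device_data = std::malloc(64);   // stands in for a buffer allocated on the GPU
        FakeModel m;
        m.tensors_by_name.push_back({"tok_embeddings.weight", &t});
        return 0;                          // m's destructor releases the "device" buffer
    }
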
@@ -698,6 +711,7 @@ struct llama_model_loader {
}
ggml_set_name(tensor, lt.name.c_str());
LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+
tensor->backend = backend;
lt.ggml_tensor = tensor;
num_ggml_tensors_created++;
@@ -850,7 +864,10 @@ static bool kv_cache_init(
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.n_ctx =*/ 512,
+ /*.n_batch =*/ 512,
/*.gpu_layers =*/ 0,
+ /*.main_gpu =*/ 0,
+ /*.tensor_split =*/ {0},
/*.seed =*/ -1,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
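
tensor_split is a per-device array of proportions for the split tensors, main_gpu selects the card that runs the non-split work (see ggml_cuda_set_main_device below), and the {0} default leaves the split entirely to the CUDA backend. The real partitioning happens inside ggml-cuda; the sketch below only illustrates, as an assumption, how such proportions could be turned into per-device row ranges:

    #include <cstdio>

    int main() {
        const int   n_devices      = 2;
        const float tensor_split[] = {3.0f, 1.0f};   // e.g. a large card and a small card
        const int   nrows          = 4096;           // rows of one split weight matrix

        float total = 0.0f;
        for (int i = 0; i < n_devices; ++i) total += tensor_split[i];

        int   row_begin = 0;
        float acc       = 0.0f;
        for (int i = 0; i < n_devices; ++i) {
            acc += tensor_split[i];
            const int row_end = (i == n_devices - 1) ? nrows : (int)(nrows * acc / total);
            printf("device %d gets rows [%d, %d)\n", i, row_begin, row_end);
            row_begin = row_end;
        }
        return 0;
    }
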
@@ -944,7 +961,10 @@ static void llama_model_load_internal(
const std::string & fname,
llama_context & lctx,
int n_ctx,
+ int n_batch,
int n_gpu_layers,
+ int main_gpu,
+ const float * tensor_split,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -959,6 +979,7 @@ static void llama_model_load_internal(
lctx.vocab = std::move(ml->file_loaders.at(0)->vocab);
auto & model = lctx.model;
model.hparams = ml->file_loaders.at(0)->hparams;
+ model.n_gpu_layers = n_gpu_layers;
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
auto & hparams = model.hparams;
@@ -1039,17 +1060,22 @@ static void llama_model_load_internal(
}
#if defined(GGML_USE_CUBLAS)
-#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
+ ggml_cuda_set_main_device(main_gpu);
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
#elif defined(GGML_USE_CLBLAST)
-#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CL
fprintf(stderr, "%s: using OpenCL for GPU acceleration\n", __func__);
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
#else
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
#endif
// prepare memory for the weights
- size_t vram_total = 0;
+ size_t vram_weights = 0;
+ size_t vram_scratch = 0;
{
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
@@ -1064,7 +1090,7 @@ static void llama_model_load_internal(
{
ggml_backend backend_output;
if (n_gpu_layers > int(n_layer)) { // NOLINT
- backend_output = LLAMA_BACKEND_OFFLOAD;
+ backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else {
backend_output = GGML_BACKEND_CPU;
}
@@ -1076,7 +1102,8 @@ static void llama_model_load_internal(
model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
- const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+ const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
auto & layer = model.layers[i];
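
There are now two offload targets: GGML_BACKEND_GPU keeps a whole tensor on the main device (and is what both macros resolve to under OpenCL), while GGML_BACKEND_GPU_SPLIT distributes a matrix across all devices according to tensor_split; CPU-only builds collapse both to GGML_BACKEND_CPU. Only the large mul_mat weights below use backend_split; the small norm weights stay on backend. A standalone sketch of the per-layer decision, with stand-in enum values:

    #include <cstdio>

    enum Backend { BACKEND_CPU, BACKEND_GPU, BACKEND_GPU_SPLIT };   // stand-ins for the ggml_backend values

    // Layers below i_gpu_start stay on the CPU; the rest put their small tensors on the
    // main GPU and their large matrices on the split backend.
    static void pick_backends(int il, int i_gpu_start, Backend * backend, Backend * backend_split) {
        if (il < i_gpu_start) {
            *backend       = BACKEND_CPU;
            *backend_split = BACKEND_CPU;
        } else {
            *backend       = BACKEND_GPU;        // e.g. attention_norm, ffn_norm
            *backend_split = BACKEND_GPU_SPLIT;  // e.g. wq/wk/wv/wo, w1/w2/w3
        }
    }

    int main() {
        const int n_layer = 32, n_gpu_layers = 20;
        const int i_gpu_start = n_layer - n_gpu_layers;
        for (int il = 0; il < n_layer; il += 8) {
            Backend b, bs;
            pick_backends(il, i_gpu_start, &b, &bs);
            printf("layer %2d: backend=%d backend_split=%d\n", il, b, bs);
        }
        return 0;
    }
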
@@ -1084,19 +1111,19 @@ static void llama_model_load_internal(
layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
- layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
- layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
- layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
- layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
+ layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split);
+ layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split);
+ layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split);
+ layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
- layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend);
- layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend);
- layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend);
+ layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split);
+ layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split);
+ layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split);
- if (backend == LLAMA_BACKEND_OFFLOAD) {
- vram_total +=
+ if (backend == GGML_BACKEND_GPU) {
+ vram_weights +=
ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
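
vram_weights only accumulates for layers whose backend is GGML_BACKEND_GPU, so the host-memory estimate further down can subtract exactly the bytes that left system RAM. For a feel of the per-layer magnitudes involved, a self-contained illustration with assumed LLaMA-7B shapes and f16 storage (quantized files are correspondingly smaller):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed LLaMA-7B hyperparameters; the real values come from the model file's hparams.
        const uint64_t n_embd = 4096;
        const uint64_t n_ff   = 11008;

        const uint64_t params_per_layer =
            4 * n_embd * n_embd +   // wq, wk, wv, wo
            3 * n_embd * n_ff   +   // w1, w2, w3
            2 * n_embd;             // attention_norm, ffn_norm

        const uint64_t bytes_f16 = params_per_layer * 2;   // 2 bytes per weight at f16
        printf("per-layer weights: %llu parameters, %.1f MB at f16\n",
               (unsigned long long) params_per_layer, bytes_f16 / (1024.0 * 1024.0));
        return 0;
    }
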
@@ -1113,7 +1140,7 @@ static void llama_model_load_internal(
// this is the total memory required to run the inference
const size_t mem_required =
ctx_size +
- mmapped_size - vram_total + // weights in VRAM not in memory
+ mmapped_size - vram_weights + // weights in VRAM not in memory
MEM_REQ_SCRATCH0().at(model.type) +
MEM_REQ_SCRATCH1().at(model.type) +
MEM_REQ_EVAL().at (model.type);
@@ -1127,12 +1154,21 @@ static void llama_model_load_internal(
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+#ifdef GGML_USE_CUBLAS
+ vram_scratch = n_batch * MB;
+ ggml_cuda_set_scratch_size(vram_scratch);
+ if (n_gpu_layers > 0) {
+ fprintf(stderr, "%s: allocating batch_size x 1 MB = %zu MB VRAM for the scratch buffer\n",
+ __func__, vram_scratch / MB);
+ }
+#endif // GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu);
if (n_gpu_layers > (int) hparams.n_layer) {
fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
}
- fprintf(stderr, "%s: total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+ fprintf(stderr, "%s: total VRAM used: %zu MB\n",
+ __func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up
#else
(void) n_gpu_layers;
#endif
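
The CUDA scratch buffer for intermediate results is sized at n_batch megabytes (which is why n_batch is now a load-time parameter), and the VRAM report adds scratch to weights and rounds the sum up to whole megabytes. A small sketch of that bookkeeping:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t MB = 1024 * 1024;

        const int    n_batch      = 512;
        const size_t vram_scratch = (size_t) n_batch * MB;   // same sizing rule as in this hunk
        const size_t vram_weights = 3900 * MB + 123456;      // made-up weight total, not a multiple of MB

        // (x + MB - 1) / MB rounds up, so a few stray bytes still count as a full megabyte.
        const size_t total_mb = (vram_weights + vram_scratch + MB - 1) / MB;
        printf("total VRAM used: %zu MB\n", total_mb);
        return 0;
    }
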
@@ -1147,6 +1183,8 @@ static void llama_model_load_internal(
#if defined(GGML_USE_CUBLAS)
{
+ ggml_cuda_set_tensor_split(tensor_split);
+
size_t done_size = 0;
size_t data_size = 0;
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
@@ -1156,7 +1194,8 @@ static void llama_model_load_internal(
}
}
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
- if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+ ggml_backend backend = lt.ggml_tensor->backend;
+ if (backend != GGML_BACKEND_GPU && backend != GGML_BACKEND_GPU_SPLIT) {
continue;
}
if (progress_callback) {
@@ -1177,7 +1216,7 @@ static void llama_model_load_internal(
}
}
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
- if (lt.ggml_tensor->backend != GGML_BACKEND_CL) {
+ if (lt.ggml_tensor->backend != GGML_BACKEND_GPU) {
continue;
}
if (progress_callback) {
@@ -1187,6 +1226,9 @@ static void llama_model_load_internal(
done_size += lt.size;
}
}
+#else
+ (void) n_batch;
+ (void) tensor_split;
#endif
if (progress_callback) {
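
The CUDA loading path now uploads every tensor whose backend is GGML_BACKEND_GPU or GGML_BACKEND_GPU_SPLIT, and it drives the progress callback from done_size/data_size, where data_size was pre-summed over just those tensors. The same two-pass accounting as a stand-in sketch:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    enum Backend { BACKEND_CPU, BACKEND_GPU, BACKEND_GPU_SPLIT };

    struct LoadTensor { Backend backend; size_t size; };

    int main() {
        std::vector<LoadTensor> tensors = {
            {BACKEND_CPU,       4096}, {BACKEND_GPU,       1024},
            {BACKEND_GPU_SPLIT, 8192}, {BACKEND_GPU_SPLIT, 8192},
        };

        // first pass: how many bytes will actually be uploaded
        size_t data_size = 0;
        for (const auto & lt : tensors) {
            if (lt.backend == BACKEND_GPU || lt.backend == BACKEND_GPU_SPLIT) {
                data_size += lt.size;
            }
        }

        // second pass: "upload" and report progress, skipping CPU-resident tensors
        size_t done_size = 0;
        for (const auto & lt : tensors) {
            if (lt.backend != BACKEND_GPU && lt.backend != BACKEND_GPU_SPLIT) {
                continue;
            }
            done_size += lt.size;
            printf("progress: %.0f%%\n", 100.0 * done_size / data_size);
        }
        return 0;
    }
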
@@ -1204,7 +1246,10 @@ static bool llama_model_load(
const std::string & fname,
llama_context & lctx,
int n_ctx,
+ int n_batch,
int n_gpu_layers,
+ int main_gpu,
+ float * tensor_split,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -1212,8 +1257,8 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
- vocab_only, progress_callback, progress_callback_user_data);
+ llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type,
+ use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::exception & err) {
fprintf(stderr, "error loading model: %s\n", err.what());
@@ -1254,12 +1299,13 @@ static bool llama_eval_internal(
LLAMA_ASSERT(!!kv_self.ctx);
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
- const int n_head = hparams.n_head;
- const int n_vocab = hparams.n_vocab;
- const int n_rot = hparams.n_embd/hparams.n_head;
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = hparams.n_ctx;
+ const int n_head = hparams.n_head;
+ const int n_vocab = hparams.n_vocab;
+ const int n_rot = hparams.n_embd/hparams.n_head;
+ const int n_gpu_layers = model.n_gpu_layers;
auto & mem_per_token = lctx.mem_per_token;
auto & buf_compute = lctx.buf_compute;
@@ -1284,7 +1330,17 @@ static bool llama_eval_internal(
struct ggml_tensor * cur;
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
+ const int i_gpu_start = n_layer - n_gpu_layers;
+
for (int il = 0; il < n_layer; ++il) {
+ offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (il >= i_gpu_start) {
+ offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
+ }
+#endif // GGML_USE_CUBLAS
+
struct ggml_tensor * inpSA = inpL;
lctx.use_buf(ctx0, 0);
@@ -1292,20 +1348,32 @@ static bool llama_eval_internal(
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_0");
// cur = cur*attention_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "attention_norm_0");
}
// self-attention
{
// compute Q and K and RoPE them
+ struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ // offload_func(tmpq);
+ ggml_set_name(tmpq, "tmpq");
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
- ggml_set_name(Qcur, "Qcur");
+ struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ // offload_func(tmpk);
+ ggml_set_name(tmpk, "tmpk");
+
+ struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
ggml_set_name(Kcur, "Kcur");
+ struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+ ggml_set_name(Qcur, "Qcur");
+
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
@@ -1313,9 +1381,11 @@ static bool llama_eval_internal(
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+ ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+ ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
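
The k and v views that just gained names are plain offset arithmetic into two flat cache buffers: the K cache is indexed layer-outermost, then position, then embedding dimension, while the V cache is stored transposed within a layer so that all positions of one embedding dimension sit next to each other. A standalone sketch of the byte offsets those ggml_view_1d/ggml_view_2d calls compute (small values assumed for illustration):

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t elem_size = 2;      // f16 cache entries
        const size_t n_embd    = 4096;
        const size_t n_ctx     = 512;
        const size_t il        = 3;      // layer index
        const size_t n_past    = 17;     // tokens already in the cache
        const size_t N         = 8;      // tokens being evaluated now

        // K cache: one [n_embd] row per (layer, position); the view starts at position n_past of layer il.
        const size_t k_offset = elem_size * n_embd * (il * n_ctx + n_past);
        const size_t k_elems  = N * n_embd;               // elements covered by the 1-D view

        // V cache: transposed storage, so consecutive positions of one embedding dimension are adjacent
        // and successive embedding dimensions are n_ctx entries apart.
        const size_t v_row_stride = n_ctx * elem_size;    // bytes between embedding dimensions
        const size_t v_offset     = il * n_ctx * elem_size * n_embd + n_past * elem_size;

        printf("K view: offset %zu bytes, %zu elements\n", k_offset, k_elems);
        printf("V view: offset %zu bytes, row stride %zu bytes\n", v_offset, v_row_stride);
        return 0;
    }
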
@@ -1390,63 +1460,104 @@ static bool llama_eval_internal(
cur = ggml_mul_mat(ctx0,
model.layers[il].wo,
cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_wo");
}
lctx.use_buf(ctx0, 1);
+ //ggml_cuda_set_scratch(1);
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+ offload_func(inpFF);
+ ggml_set_name(inpFF, "inpFF");
// feed-forward network
{
// norm
{
cur = ggml_rms_norm(ctx0, inpFF);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_1");
// cur = cur*ffn_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "ffn_norm");
}
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model.layers[il].w3,
cur);
+ offload_func(tmp);
+ ggml_set_name(tmp, "result_w3");
cur = ggml_mul_mat(ctx0,
model.layers[il].w1,
cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w1");
// SILU activation
cur = ggml_silu(ctx0, cur);
+ offload_func(cur);
+ ggml_set_name(cur, "silu");
cur = ggml_mul(ctx0, cur, tmp);
+ offload_func(cur);
+ ggml_set_name(cur, "silu_x_result_w3");
cur = ggml_mul_mat(ctx0,
model.layers[il].w2,
cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w2");
}
cur = ggml_add(ctx0, cur, inpFF);
+ offload_func(cur);
+ ggml_set_name(cur, "inpFF_+_result_w2");
// input for next layer
inpL = cur;
+
}
lctx.use_buf(ctx0, 0);
+ //ggml_cuda_set_scratch(0);
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;
+ offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer) {
+ offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
+ }
+#endif // GGML_USE_CUBLAS
+
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_inpL");
+
+ cur = ggml_rms_norm(ctx0, cur);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_after");
// cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.norm);
+ offload_func(cur);
+ ggml_set_name(cur, "result_norm");
embeddings = cur;
}
+
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
+ ggml_set_name(cur, "result_output");
lctx.use_buf(ctx0, -1);
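
Nearly every node added to the graph now gets the same two-line treatment: offload_func(t) to tag where its data should live, then ggml_set_name(t, ...) so the tensor can be identified when inspecting the graph or the CUDA scratch usage. A hypothetical convenience wrapper (not part of this diff) that pairs the two calls, assuming ggml.h and the offload_func_t typedef introduced above:

    #include "ggml.h"

    typedef void (*offload_func_t)(struct ggml_tensor * tensor);

    // Hypothetical helper: tag a freshly created node for offloading and give it a debug name.
    static struct ggml_tensor * offload_and_name(struct ggml_tensor * t,
                                                 offload_func_t       offload_func,
                                                 const char *         name) {
        offload_func(t);
        ggml_set_name(t, name);
        return t;
    }

    // The layer loop above could then read, for example:
    //     cur = offload_and_name(ggml_rms_norm(ctx0, inpFF), offload_func, "rms_norm_1");
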
@@ -2366,9 +2477,9 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
- params.use_mmap, params.use_mlock, params.vocab_only,
- params.progress_callback, params.progress_callback_user_data)) {
+ if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers,
+ params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock,
+ params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
return nullptr;
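
Putting the new parameters together, a caller that wants most of the model in VRAM only has to touch llama_context_params before the usual llama_init_from_file call. A hypothetical minimal caller (the n_gpu_layers field name is assumed from llama.h; error handling kept to a minimum):

    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s <model.bin>\n", argv[0]);
            return 1;
        }

        llama_context_params params = llama_context_default_params();
        params.n_ctx        = 2048;
        params.n_batch      = 512;   // also sets the CUDA scratch buffer to n_batch MB
        params.n_gpu_layers = 35;    // layers to keep in VRAM; > n_layer also offloads the output layer
        params.main_gpu     = 0;     // device that runs the non-split work
        // params.tensor_split stays all-zero here; set per-device proportions for a multi-GPU split

        llama_context * ctx = llama_init_from_file(argv[1], params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load %s\n", argv[1]);
            return 1;
        }

        llama_free(ctx);
        return 0;
    }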