about summary refs log tree commit diff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r-- llama.cpp 326
1 files changed, 265 insertions, 61 deletions
diff --git a/llama.cpp b/llama.cpp
index 9a93409..9d48ccd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5,12 +5,25 @@
#include <cinttypes>
#include <fstream>
#include <random>
+#include <map>
#include <unordered_map>
#include <queue>
#include <regex>
#include <cassert>
#include <cstring>
+#define LLAMA_USE_SCRATCH
+#define LLAMA_MAX_SCRATCH_BUFFERS 16
+// LLAMA_ASSERT(x): if x is false, print file, line and the expression text to stderr, then abort(); there is no NDEBUG check, so it is always active
+#define LLAMA_ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+ abort(); \
+ } \
+ } while (0)
+
+
// determine number of model parts based on the dimension
static const std::unordered_map<int, int> LLAMA_N_PARTS = {
{ 4096, 1 },
@@ -19,6 +32,52 @@ static const std::unordered_map<int, int> LLAMA_N_PARTS = {
{ 8192, 8 },
};
+// available llama models; llama_model_load picks one from hparams.n_layer
+enum e_model {
+ MODEL_UNKNOWN, // default of llama_model::type; has no entry in the MEM_REQ_* tables, so .at() on them throws std::out_of_range
+ MODEL_7B, // selected when n_layer == 32
+ MODEL_13B, // selected when n_layer == 40
+ MODEL_30B, // selected when n_layer == 60
+ MODEL_65B, // selected when n_layer == 80
+};
+
+static const size_t MB = 1024*1024; // number of bytes in one MB (2^20); unit used by the MEM_REQ_* tables
+
+// the sizes in the MEM_REQ_* tables were computed for n_ctx == 2048
+// TODO: determine these sizes dynamically
+// (needs modifications in ggml)
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = { // bytes reserved for llama_context::buf_scratch[0], per model type
+ { MODEL_7B, 512ull*MB },
+ { MODEL_13B, 512ull*MB },
+ { MODEL_30B, 512ull*MB },
+ { MODEL_65B, 512ull*MB },
+};
+
+static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = { // bytes reserved for llama_context::buf_scratch[1], per model type
+ { MODEL_7B, 512ull*MB },
+ { MODEL_13B, 512ull*MB },
+ { MODEL_30B, 512ull*MB },
+ { MODEL_65B, 512ull*MB },
+};
+
+// approximately 2*n_embd*n_ctx*n_layer*sizeof(float16), i.e. the f16 size of the k and v tensors together
+static const std::map<e_model, size_t> MEM_REQ_KV_SELF = { // in the code visible here, used only for the per-state estimate printed by llama_model_load, which doubles it when the memory type is F32
+ { MODEL_7B, 1026ull*MB },
+ { MODEL_13B, 1608ull*MB },
+ { MODEL_30B, 3124ull*MB },
+ { MODEL_65B, 5120ull*MB },
+};
+
+// this is mostly needed for temporary mul_mat buffers to dequantize the data
+// not actually needed if BLAS is disabled
+static const std::map<e_model, size_t> MEM_REQ_EVAL = { // bytes reserved for llama_context::buf_compute, the buffer llama_eval_internal hands to ggml_init
+ { MODEL_7B, 768ull*MB },
+ { MODEL_13B, 1024ull*MB },
+ { MODEL_30B, 1280ull*MB },
+ { MODEL_65B, 1536ull*MB },
+};
+
// default hparams (LLaMA 7B)
struct llama_hparams {
int32_t n_vocab = 32000;
@@ -50,7 +109,20 @@ struct llama_layer {
struct ggml_tensor * w3;
};
+// Key/value cache for the self attention, together with the ggml context and
+// the backing buffer its tensors live in.
+//
+// All pointer members start out null: llama_free() calls kv_cache_free() even
+// when kv_cache_init() never ran (e.g. after a failed model load), and
+// kv_cache_free() decides what to release by testing ctx.
+struct llama_kv_cache {
+    struct ggml_tensor * k = nullptr; // keys,   created by kv_cache_init()
+    struct ggml_tensor * v = nullptr; // values, created by kv_cache_init()
+
+    struct ggml_context * ctx = nullptr; // ggml context owning k and v; null until kv_cache_init() succeeds
+
+    std::vector<uint8_t> buf; // memory handed to ggml_init() for ctx
+
+    int n = 0; // number of tokens currently in the cache
+};
+
struct llama_model {
+ e_model type = MODEL_UNKNOWN;
+
llama_hparams hparams;
struct ggml_tensor * tok_embeddings;
@@ -60,12 +132,18 @@ struct llama_model {
std::vector<llama_layer> layers;
- // key + value memory
- struct ggml_tensor * memory_k;
- struct ggml_tensor * memory_v;
-
- //
+ // context
struct ggml_context * ctx;
+
+ // key + value cache for the self attention
+ // TODO: move to llama_state
+ struct llama_kv_cache kv_self;
+
+ // the model memory buffer
+ std::vector<uint8_t> buf;
+
+ // tensors
+ int n_loaded;
std::unordered_map<std::string, struct ggml_tensor *> tensors;
};
@@ -105,8 +183,88 @@ struct llama_context {
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
+
+ // memory buffers used to evaluate the model
+ // TODO: move to llama_state
+ std::vector<uint8_t> buf_compute;
+ std::vector<uint8_t> buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+
+ int buf_last = 0;
+ size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
+
+ void use_buf(struct ggml_context * ctx, int i) { // point ctx at scratch buffer i, or at no scratch buffer when i == -1
+#if defined(LLAMA_USE_SCRATCH)
+ size_t last_size = 0; // value returned by ggml_set_scratch for this switch
+
+ if (i == -1) {
+ last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); // empty descriptor: size 0, null data pointer
+ } else {
+ auto & buf = buf_scratch[i]; // no bounds check: caller must keep 0 <= i < LLAMA_MAX_SCRATCH_BUFFERS
+ last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
+ }
+
+ if (buf_last >= 0) { // nothing to record when the previous selection was -1
+ buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); // presumably last_size is how much of the previously selected buffer was used — confirm against ggml_set_scratch
+ }
+
+ buf_last = i; // remember the selection so the next call can attribute its last_size
+#else
+ (void) i; // scratch support compiled out: only silence unused-parameter warnings
+ (void) ctx;
+#endif
+ }
+
+ size_t get_buf_max_mem(int i) const { // largest value use_buf has recorded for scratch buffer i; always 0 when scratch support is compiled out
+#if defined(LLAMA_USE_SCRATCH)
+ return buf_max_size[i]; // no bounds check on i
+#else
+ (void) i;
+ return 0;
+#endif
+ }
};
+//
+// kv cache
+//
+
+// Allocates the backing buffer of a KV cache, creates a ggml context on top of
+// it and creates the k and v tensors inside that context.
+//
+//   hparams : supplies n_embd and n_layer
+//   cache   : cache to set up; its buf, ctx, k and v members are overwritten
+//   wtype   : element type of the k and v tensors
+//   n_ctx   : number of context positions to reserve per layer
+//
+// Returns false (after logging to stderr) when ggml_init() yields no context,
+// true otherwise.
+static bool kv_cache_init(
+        const struct llama_hparams & hparams,
+        struct llama_kv_cache & cache,
+        ggml_type wtype,
+        int n_ctx) {
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    // Do the size arithmetic in 64 bits. In plain int, 2*n_embd*n_layer*n_ctx
+    // already overflows for n_embd = 8192, n_layer = 80, n_ctx = 2048
+    // (2684354560 > INT_MAX), which is undefined behaviour.
+    const int64_t n_mem      = (int64_t) n_layer*n_ctx;
+    const int64_t n_elements = n_embd*n_mem;
+
+    // data of both tensors plus 2 MB of slack
+    // (presumably for ggml's own object headers - confirm against ggml)
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+
+    // value-initialize so that any field not assigned below is zero
+    struct ggml_init_params params = {};
+    params.mem_size   = cache.buf.size();
+    params.mem_buffer = cache.buf.data();
+
+    cache.ctx = ggml_init(params);
+
+    if (!cache.ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
+        return false;
+    }
+
+    // NOTE(review): if ggml_new_tensor_1d in this ggml version takes an int
+    // element count, a count above INT_MAX is still truncated here - confirm.
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+
+    return true;
+}
+
+static void kv_cache_free(struct llama_kv_cache & cache) { // releases the cache's ggml context, if it has one
+ if (cache.ctx) {
+ ggml_free(cache.ctx);
+ cache.ctx = nullptr; // makes a repeated call a no-op
+ }
+} // cache.buf, cache.k and cache.v are left untouched
+
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.n_ctx =*/ 512,
@@ -204,6 +362,22 @@ static bool llama_model_load(
fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
}
+ if (hparams.n_layer == 32) {
+ model.type = e_model::MODEL_7B;
+ }
+
+ if (hparams.n_layer == 40) {
+ model.type = e_model::MODEL_13B;
+ }
+
+ if (hparams.n_layer == 60) {
+ model.type = e_model::MODEL_30B;
+ }
+
+ if (hparams.n_layer == 80) {
+ model.type = e_model::MODEL_65B;
+ }
+
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
fprintf(stderr, "%s: n_embd = %d\n", __func__, hparams.n_embd);
@@ -214,6 +388,7 @@ static bool llama_model_load(
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: n_ff = %d\n", __func__, n_ff);
fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
+ fprintf(stderr, "%s: type = %d\n", __func__, model.type);
}
// load vocab
@@ -307,11 +482,32 @@ static bool llama_model_load(
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
+ // print memory requirements
+ {
+ const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+ // this is the total memory required to run the inference
+ const size_t mem_required =
+ ctx_size +
+ MEM_REQ_SCRATCH0.at(model.type) +
+ MEM_REQ_SCRATCH1.at(model.type) +
+ MEM_REQ_EVAL.at (model.type);
+
+ // this is the memory required by one llama_state
+ const size_t mem_required_state =
+ scale*MEM_REQ_KV_SELF.at(model.type);
+
+ fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+ mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+ }
+
// create the ggml context
{
+ lctx.model.buf.resize(ctx_size);
+
struct ggml_init_params params = {
- /*.mem_size =*/ ctx_size,
- /*.mem_buffer =*/ NULL,
+ /*.mem_size =*/ lctx.model.buf.size(),
+ /*.mem_buffer =*/ lctx.model.buf.data(),
};
model.ctx = ggml_init(params);
@@ -374,25 +570,6 @@ static bool llama_model_load(
}
}
- // key + value memory
- {
- const auto & hparams = model.hparams;
-
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
-
- const int n_mem = n_layer*n_ctx;
- const int n_elements = n_embd*n_mem;
-
- model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);
-
- const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
-
- fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
- }
-
const size_t file_offset = fin.tellg();
fin.close();
@@ -416,9 +593,10 @@ static bool llama_model_load(
// load weights
{
- int n_tensors = 0;
size_t total_size = 0;
+ model.n_loaded = 0;
+
fprintf(stderr, "%s: ", __func__);
while (true) {
@@ -583,7 +761,10 @@ static bool llama_model_load(
}
//fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
- if (++n_tensors % 8 == 0) {
+ model.n_loaded++;
+
+ // progress
+ if (model.n_loaded % 8 == 0) {
fprintf(stderr, ".");
fflush(stderr);
}
@@ -591,7 +772,13 @@ static bool llama_model_load(
fprintf(stderr, " done\n");
- fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+ fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
+ if (model.n_loaded == 0) {
+ fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+ } else if (model.n_loaded != (int) model.tensors.size()) {
+ fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+ return false;
+ }
}
fin.close();
@@ -622,6 +809,10 @@ static bool llama_eval_internal(
const auto & model = lctx.model;
const auto & hparams = model.hparams;
+ auto & kv_self = model.kv_self;
+
+ LLAMA_ASSERT(!!kv_self.ctx);
+
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
@@ -630,27 +821,11 @@ static bool llama_eval_internal(
const int n_rot = hparams.n_embd/hparams.n_head;
auto & mem_per_token = lctx.mem_per_token;
-
- // TODO: fix this hardcoded size
- static size_t buf_size = 2048u*1024*1024; // TMP !!!
- static void * buf = malloc(buf_size);
-
- if (mem_per_token > 0 && mem_per_token*N > buf_size) {
- const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
- //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
- // reallocate
- buf_size = buf_size_new;
- buf = realloc(buf, buf_size);
- if (buf == nullptr) {
- fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
- return false;
- }
- }
+ auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = {
- /*.mem_size =*/ buf_size,
- /*.mem_buffer =*/ buf,
+ /*.mem_size =*/ buf_compute.size(),
+ /*.mem_buffer =*/ buf_compute.data(),
};
struct ggml_context * ctx0 = ggml_init(params);
@@ -667,6 +842,8 @@ static bool llama_eval_internal(
struct ggml_tensor * cur;
+ lctx.use_buf(ctx0, 0);
+
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
@@ -685,8 +862,8 @@ static bool llama_eval_internal(
// store key and value to memory
if (N >= 1) {
- struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -707,7 +884,7 @@ static bool llama_eval_internal(
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
n_past, n_rot, 1),
0, 2, 1, 3);
@@ -733,7 +910,7 @@ static bool llama_eval_internal(
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+ ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
@@ -755,6 +932,8 @@ static bool llama_eval_internal(
cur);
}
+ lctx.use_buf(ctx0, 1);
+
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
// feed-forward network
@@ -773,7 +952,6 @@ static bool llama_eval_internal(
model.layers[il].w3,
cur);
-
cur = ggml_mul_mat(ctx0,
model.layers[il].w1,
cur);
@@ -788,17 +966,20 @@ static bool llama_eval_internal(
cur);
}
- cur = ggml_add(ctx0, cur, inpFF);
+ cur = ggml_add(ctx0, cur, inpFF);
// input for next layer
inpL = cur;
}
+ lctx.use_buf(ctx0, 0);
+
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;
// norm
{
+
inpL = ggml_rms_norm(ctx0, inpL);
// inpL = norm*inpL
@@ -810,9 +991,9 @@ static bool llama_eval_internal(
}
// lm_head
- {
- inpL = ggml_mul_mat(ctx0, model.output, inpL);
- }
+ inpL = ggml_mul_mat(ctx0, model.output, inpL);
+
+ lctx.use_buf(ctx0, -1);
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
@@ -854,7 +1035,13 @@ static bool llama_eval_internal(
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
- //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+#if 0
+ printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
+ ggml_used_mem(ctx0)/1024.0/1024.0,
+ lctx.get_buf_max_mem(0)/1024.0/1024.0,
+ lctx.get_buf_max_mem(1)/1024.0/1024.0);
+#endif
ggml_free(ctx0);
@@ -1427,9 +1614,9 @@ struct llama_context * llama_init_from_file(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
- ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
+ ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory,
+ if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
params.vocab_only)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
@@ -1448,6 +1635,17 @@ struct llama_context * llama_init_from_file(
// reserve memory for context buffers
{
+ if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+ fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
+ llama_free(ctx);
+ return nullptr;
+ }
+
+ {
+ const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
+ fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
+ }
+
const auto & hparams = ctx->model.hparams;
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
@@ -1458,12 +1656,19 @@ struct llama_context * llama_init_from_file(
if (params.embedding){
ctx->embedding.reserve(hparams.n_embd);
}
+
+ ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
+
+ ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
+ ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
}
return ctx;
}
void llama_free(struct llama_context * ctx) {
+ kv_cache_free(ctx->model.kv_self);
+
if (ctx->model.ctx) {
ggml_free(ctx->model.ctx);
}
@@ -1619,4 +1824,3 @@ const char * llama_print_system_info(void) {
return s.c_str();
}
-