aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp56
1 files changed, 46 insertions, 10 deletions
diff --git a/llama.cpp b/llama.cpp
index d552192..d8c7715 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;
+
+ // input embedding (1-dimensional array: [n_embd])
+ std::vector<float> embedding;
};
struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
+ /*.embedding =*/ false,
};
return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
fin.close();
}
- lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
lctx.t_load_us = ggml_time_us() - t_start_us;
return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
inpL = cur;
}
+ // used at the end to optionally extract the embeddings
+ struct ggml_tensor * embeddings = NULL;
+
// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);
+
+ embeddings = inpL;
}
// lm_head
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- auto & logits_out = lctx.logits;
+ // extract logits
+ {
+ auto & logits_out = lctx.logits;
+
+ if (lctx.logits_all) {
+ logits_out.resize(n_vocab * N);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+ } else {
+ // return result for just the last token
+ logits_out.resize(n_vocab);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ }
+ }
+
+ // extract embeddings
+ if (lctx.embedding.size()) {
+ auto & embedding_out = lctx.embedding;
- if (lctx.logits_all) {
- logits_out.resize(n_vocab * N);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- } else {
- // return result for just the last token
- logits_out.resize(n_vocab);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ embedding_out.resize(n_embd);
+ memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}
if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
return nullptr;
}
+ // reserve memory for context buffers
+ {
+ const auto & hparams = ctx->model.hparams;
+ if (params.logits_all) {
+ ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+ } else {
+ ctx->logits.reserve(hparams.n_ctx);
+ }
+
+ if (params.embedding){
+ ctx->embedding.reserve(hparams.n_embd);
+ }
+ }
+
return ctx;
}
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}
+float * llama_get_embeddings(struct llama_context * ctx) {
+ return ctx->embedding.data();
+}
+
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;