author     Luciano <lucianostrika44@gmail.com>   2023-03-24 08:05:13 -0700
committer  GitHub <noreply@github.com>           2023-03-24 17:05:13 +0200
commit     8d4a855c241ecb0f3ddc03447fe56002ebf27a37 (patch)
tree       4de329fb2849fb6128d05237850b8ceb7519bf36 /llama.cpp
parent     b6b268d4415fd3b3e53f22b6619b724d4928f713 (diff)
Add embedding mode with arg flag. Currently working (#282)
* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesnt compile
* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--   llama.cpp   56
1 file changed, 46 insertions, 10 deletions
diff --git a/llama.cpp b/llama.cpp
index d552192..d8c7715 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;
+
+ // input embedding (1-dimensional array: [n_embd])
+ std::vector<float> embedding;
};
struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
+ /*.embedding =*/ false,
};
return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
fin.close();
}
- lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
lctx.t_load_us = ggml_time_us() - t_start_us;
return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
inpL = cur;
}
+ // used at the end to optionally extract the embeddings
+ struct ggml_tensor * embeddings = NULL;
+
// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);
+
+ embeddings = inpL;
}
// lm_head
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- auto & logits_out = lctx.logits;
+ // extract logits
+ {
+ auto & logits_out = lctx.logits;
+
+ if (lctx.logits_all) {
+ logits_out.resize(n_vocab * N);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+ } else {
+ // return result for just the last token
+ logits_out.resize(n_vocab);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ }
+ }
+
+ // extract embeddings
+ if (lctx.embedding.size()) {
+ auto & embedding_out = lctx.embedding;
- if (lctx.logits_all) {
- logits_out.resize(n_vocab * N);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- } else {
- // return result for just the last token
- logits_out.resize(n_vocab);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ embedding_out.resize(n_embd);
+ memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}
if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
return nullptr;
}
+ // reserve memory for context buffers
+ {
+ const auto & hparams = ctx->model.hparams;
+ if (params.logits_all) {
+ ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+ } else {
+ ctx->logits.reserve(hparams.n_ctx);
+ }
+
+ if (params.embedding){
+ ctx->embedding.reserve(hparams.n_embd);
+ }
+ }
+
return ctx;
}
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}
+float * llama_get_embeddings(struct llama_context * ctx) {
+ return ctx->embedding.data();
+}
+
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;
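
For orientation, below is a minimal usage sketch of the embedding mode this commit introduces. Only llama_context_default_params, llama_init_from_file and llama_get_embeddings appear in the diff above; llama_tokenize, llama_eval, llama_n_embd and llama_free are assumed from llama.h at this revision, and the model path is a placeholder. This is an illustrative sketch, not part of the commit.

// embedding_sketch.cpp -- illustrative only, not part of this commit.
// Assumes the llama.h API at this revision; llama_tokenize, llama_eval,
// llama_n_embd and llama_free are not shown in the diff above.
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_context_params params = llama_context_default_params();
    params.embedding = true; // the new flag added by this commit

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == nullptr) {
        return 1;
    }

    // tokenize the prompt
    const std::string prompt = "Hello, world";
    std::vector<llama_token> tokens(prompt.size() + 1);
    const int n_tokens = llama_tokenize(ctx, prompt.c_str(), tokens.data(), (int) tokens.size(), true);
    tokens.resize(n_tokens);

    // evaluate the prompt; with params.embedding set, the eval path above
    // copies the last token's hidden state into the context's embedding buffer
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, /*n_threads =*/ 4) != 0) {
        return 1;
    }

    // read the embedding back out through the new accessor
    const int     n_embd = llama_n_embd(ctx);
    const float * embd   = llama_get_embeddings(ctx);
    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embd[i]);
    }
    printf("\n");

    llama_free(ctx);
    return 0;
}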