author     Luciano <lucianostrika44@gmail.com>      2023-03-24 08:05:13 -0700
committer  GitHub <noreply@github.com>              2023-03-24 17:05:13 +0200
commit     8d4a855c241ecb0f3ddc03447fe56002ebf27a37 (patch)
tree       4de329fb2849fb6128d05237850b8ceb7519bf36
parent     b6b268d4415fd3b3e53f22b6619b724d4928f713 (diff)
Add embedding mode with arg flag. Currently working (#282)
* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesn't compile
* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r--  llama.cpp   56
-rw-r--r--  llama.h      5
-rw-r--r--  main.cpp    23
-rw-r--r--  utils.cpp    4
-rw-r--r--  utils.h      4
5 files changed, 82 insertions, 10 deletions
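
The commit threads a single boolean through the stack: a new `embedding` field in `llama_context_params`, a `--embedding` flag in the argument parser, an embedding-only code path in main.cpp, and a new `llama_get_embeddings()` accessor in the C API. The block below is a minimal sketch of how a caller might use the new API; it is not part of the commit, and it assumes the `llama_tokenize()`, `llama_n_embd()`, `llama_eval()` and `llama_free()` declarations present in llama.h at this point in the tree, with a placeholder model path.

    // Sketch only (not part of the commit): embedding-only use of the C API.
    #include "llama.h"

    #include <cstdio>
    #include <vector>

    int main() {
        llama_context_params lparams = llama_context_default_params();
        lparams.embedding = true; // new field added by this commit

        llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
        if (ctx == nullptr) {
            return 1;
        }

        // tokenize the prompt, prepending BOS
        std::vector<llama_token> tokens(64);
        const int n_tokens = llama_tokenize(ctx, "Hello world", tokens.data(), (int) tokens.size(), true);
        if (n_tokens < 0) {
            return 1;
        }
        tokens.resize(n_tokens);

        // a single eval is enough: the context keeps the hidden state of the last token
        if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, 4) != 0) {
            return 1;
        }

        const float * emb    = llama_get_embeddings(ctx); // new accessor added by this commit
        const int     n_embd = llama_n_embd(ctx);

        printf("n_embd = %d, embedding[0] = %f\n", n_embd, emb[0]);

        llama_free(ctx);
        return 0;
    }
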
diff --git a/llama.cpp b/llama.cpp
index d552192..d8c7715 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;
+
+ // input embedding (1-dimensional array: [n_embd])
+ std::vector<float> embedding;
};
struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
+ /*.embedding =*/ false,
};
return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
fin.close();
}
- lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
lctx.t_load_us = ggml_time_us() - t_start_us;
return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
inpL = cur;
}
+ // used at the end to optionally extract the embeddings
+ struct ggml_tensor * embeddings = NULL;
+
// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);
+
+ embeddings = inpL;
}
// lm_head
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- auto & logits_out = lctx.logits;
+ // extract logits
+ {
+ auto & logits_out = lctx.logits;
+
+ if (lctx.logits_all) {
+ logits_out.resize(n_vocab * N);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+ } else {
+ // return result for just the last token
+ logits_out.resize(n_vocab);
+ memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ }
+ }
+
+ // extract embeddings
+ if (lctx.embedding.size()) {
+ auto & embedding_out = lctx.embedding;
- if (lctx.logits_all) {
- logits_out.resize(n_vocab * N);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
- } else {
- // return result for just the last token
- logits_out.resize(n_vocab);
- memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+ embedding_out.resize(n_embd);
+ memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}
if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
return nullptr;
}
+ // reserve memory for context buffers
+ {
+ const auto & hparams = ctx->model.hparams;
+ if (params.logits_all) {
+ ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+ } else {
+ ctx->logits.reserve(hparams.n_ctx);
+ }
+
+ if (params.embedding){
+ ctx->embedding.reserve(hparams.n_embd);
+ }
+ }
+
return ctx;
}
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}
+float * llama_get_embeddings(struct llama_context * ctx) {
+ return ctx->embedding.data();
+}
+
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;
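
In llama_eval_internal the tensor saved in `embeddings` is the output of the final RMS norm, evaluated for all N tokens of the batch; its data is a flat, row-major buffer of N rows of n_embd floats, and the memcpy above keeps only the last row as the sequence embedding. A standalone illustration of that offset arithmetic (not from the commit):

    // Illustration only: copy the last of N rows out of a flat row-major
    // [N][n_embd] float buffer, which is what the embedding extraction above
    // does with ggml_get_data(embeddings).
    #include <cstring>
    #include <vector>

    static void copy_last_row(const float * data, int N, int n_embd, std::vector<float> & out) {
        out.resize(n_embd);
        std::memcpy(out.data(), data + (size_t) n_embd * (N - 1), sizeof(float) * n_embd);
    }
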
diff --git a/llama.h b/llama.h
index 3df9ed1..209b4db 100644
--- a/llama.h
+++ b/llama.h
@@ -53,6 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
+ bool embedding; // embedding mode only
};
LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
+ // Get the embeddings for the input
+ // shape: [n_embd] (1-dimensional)
+ LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
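
The new accessor returns a single n_embd-dimensional vector for the evaluated input, and the buffer behind the pointer is reused by the next llama_eval() call, so callers that want to keep or compare embeddings should copy them out first. As a hedged example of downstream use (not part of the commit or the header), a cosine-similarity helper over two such vectors:

    // Illustration only (not in llama.h): cosine similarity between two
    // n_embd-length embedding vectors, e.g. two llama_get_embeddings() results
    // copied out of their contexts before the next eval overwrites them.
    #include <cmath>

    static float cosine_similarity(const float * a, const float * b, int n_embd) {
        float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
        for (int i = 0; i < n_embd; ++i) {
            dot    += a[i] * b[i];
            norm_a += a[i] * a[i];
            norm_b += b[i] * b[i];
        }
        return dot / (std::sqrt(norm_a) * std::sqrt(norm_b) + 1e-6f);
    }
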
diff --git a/main.cpp b/main.cpp
index 5ba6d5a..46a80ff 100644
--- a/main.cpp
+++ b/main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
+ lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
+
int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
+ if (params.embedding){
+ embd = embd_inp;
+
+ if (embd.size() > 0) {
+ if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
+ }
+ }
+
+ const auto embeddings = llama_get_embeddings(ctx);
+
+ // TODO: print / use the embeddings
+
+ if (params.use_color) {
+ printf(ANSI_COLOR_RESET);
+ }
+
+ return 0;
+ }
+
while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {
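
The embedding path added to main.cpp stops at a TODO; one way it might later be filled in (a sketch, not part of the commit) is to dump the vector to stdout, assuming the llama_n_embd() accessor from llama.h:

    // Possible follow-up for the "TODO: print / use the embeddings" above (sketch only).
    // `embeddings` is the float * returned by llama_get_embeddings(ctx).
    const int n_embd = llama_n_embd(ctx);
    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embeddings[i]);
    }
    printf("\n");
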
diff --git a/utils.cpp b/utils.cpp
index 45c9cab..0df89af 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
+ } else if (arg == "--embedding") {
+ params.embedding = true;
+ } else if (arg == "--interactive-start") {
+ params.interactive = true;
} else if (arg == "--interactive-first") {
params.interactive_start = true;
} else if (arg == "-ins" || arg == "--instruct") {
diff --git a/utils.h b/utils.h
index b0de556..8120c12 100644
--- a/utils.h
+++ b/utils.h
@@ -32,13 +32,17 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
+
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
+
+ bool embedding = false; // get only sentence embedding
bool interactive_start = false; // wait for user input immediately
+
bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt