diff options
author | Luciano <lucianostrika44@gmail.com> | 2023-03-24 08:05:13 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-24 17:05:13 +0200 |
commit | 8d4a855c241ecb0f3ddc03447fe56002ebf27a37 (patch) | |
tree | 4de329fb2849fb6128d05237850b8ceb7519bf36 | |
parent | b6b268d4415fd3b3e53f22b6619b724d4928f713 (diff) |
Add embedding mode with arg flag. Currently working (#282)
* working but ugly
* add arg flag, not working on embedding mode
* typo
* Working! Thanks to @nullhook
* make params argument instead of hardcoded boolean. remove useless time check
* start doing the instructions but not finished. This probably doesnt compile
* Embeddings extraction support
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- | llama.cpp | 56 | ||||
-rw-r--r-- | llama.h | 5 | ||||
-rw-r--r-- | main.cpp | 23 | ||||
-rw-r--r-- | utils.cpp | 4 | ||||
-rw-r--r-- | utils.h | 4 |
5 files changed, 82 insertions, 10 deletions
@@ -102,6 +102,9 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector<float> logits; bool logits_all = false; + + // input embedding (1-dimensional array: [n_embd]) + std::vector<float> embedding; }; struct llama_context_params llama_context_default_params() { @@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() { /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, + /*.embedding =*/ false, }; return result; @@ -592,8 +596,6 @@ static bool llama_model_load( fin.close(); } - lctx.logits.reserve(lctx.model.hparams.n_ctx); - lctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -791,6 +793,9 @@ static bool llama_eval_internal( inpL = cur; } + // used at the end to optionally extract the embeddings + struct ggml_tensor * embeddings = NULL; + // norm { inpL = ggml_rms_norm(ctx0, inpL); @@ -799,6 +804,8 @@ static bool llama_eval_internal( inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL); + + embeddings = inpL; } // lm_head @@ -821,15 +828,26 @@ static bool llama_eval_internal( //embd_w.resize(n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); - auto & logits_out = lctx.logits; + // extract logits + { + auto & logits_out = lctx.logits; + + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } + } + + // extract embeddings + if (lctx.embedding.size()) { + auto & embedding_out = lctx.embedding; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); } if (mem_per_token == 0) { @@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file( return nullptr; } + // reserve memory for context buffers + { + const auto & hparams = ctx->model.hparams; + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_ctx); + } + + if (params.embedding){ + ctx->embedding.reserve(hparams.n_embd); + } + } + return ctx; } @@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + const char * llama_token_to_str(struct llama_context * ctx, llama_token token) { if (token >= llama_n_vocab(ctx)) { return nullptr; @@ -53,6 +53,7 @@ extern "C" { bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights + bool embedding; // embedding mode only }; LLAMA_API struct llama_context_params llama_context_default_params(); @@ -108,6 +109,10 @@ extern "C" { // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + // Token Id -> String. Uses the vocabulary in the provided context LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token); @@ -199,6 +199,7 @@ int main(int argc, char ** argv) { lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; ctx = llama_init_from_file(params.model.c_str(), lparams); @@ -292,6 +293,7 @@ int main(int argc, char ** argv) { std::vector<llama_token> embd; + int last_n_size = params.repeat_last_n; std::vector<llama_token> last_n_tokens(last_n_size); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -324,6 +326,27 @@ int main(int argc, char ** argv) { // the first thing we will do is to output the prompt, so set color accordingly set_console_state(CONSOLE_STATE_PROMPT); + if (params.embedding){ + embd = embd_inp; + + if (embd.size() > 0) { + if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + } + + const auto embeddings = llama_get_embeddings(ctx); + + // TODO: print / use the embeddings + + if (params.use_color) { + printf(ANSI_COLOR_RESET); + } + + return 0; + } + while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { @@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.model = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; + } else if (arg == "--embedding") { + params.embedding = true; + } else if (arg == "--interactive-start") { + params.interactive = true; } else if (arg == "--interactive-first") { params.interactive_start = true; } else if (arg == "-ins" || arg == "--instruct") { @@ -32,13 +32,17 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; + std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted bool memory_f16 = false; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + + bool embedding = false; // get only sentence embedding bool interactive_start = false; // wait for user input immediately + bool instruct = false; // instruction mode (used for Alpaca models) bool ignore_eos = false; // do not stop generating after eos bool perplexity = false; // compute perplexity over the prompt |