From 8d4a855c241ecb0f3ddc03447fe56002ebf27a37 Mon Sep 17 00:00:00 2001 From: Luciano Date: Fri, 24 Mar 2023 08:05:13 -0700 Subject: Add embedding mode with arg flag. Currently working (#282) * working but ugly * add arg flag, not working on embedding mode * typo * Working! Thanks to @nullhook * make params argument instead of hardcoded boolean. remove useless time check * start doing the instructions but not finished. This probably doesnt compile * Embeddings extraction support --------- Co-authored-by: Georgi Gerganov --- llama.cpp | 56 ++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 10 deletions(-) (limited to 'llama.cpp') diff --git a/llama.cpp b/llama.cpp index d552192..d8c7715 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,6 +102,9 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; bool logits_all = false; + + // input embedding (1-dimensional array: [n_embd]) + std::vector embedding; }; struct llama_context_params llama_context_default_params() { @@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() { /*.f16_kv =*/ false, /*.logits_all =*/ false, /*.vocab_only =*/ false, + /*.embedding =*/ false, }; return result; @@ -592,8 +596,6 @@ static bool llama_model_load( fin.close(); } - lctx.logits.reserve(lctx.model.hparams.n_ctx); - lctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -791,6 +793,9 @@ static bool llama_eval_internal( inpL = cur; } + // used at the end to optionally extract the embeddings + struct ggml_tensor * embeddings = NULL; + // norm { inpL = ggml_rms_norm(ctx0, inpL); @@ -799,6 +804,8 @@ static bool llama_eval_internal( inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL); + + embeddings = inpL; } // lm_head @@ -821,15 +828,26 @@ static bool llama_eval_internal( //embd_w.resize(n_vocab*N); //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); - auto & logits_out = lctx.logits; + // extract logits + { + auto & logits_out = lctx.logits; + + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } + } + + // extract embeddings + if (lctx.embedding.size()) { + auto & embedding_out = lctx.embedding; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); } if (mem_per_token == 0) { @@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file( return nullptr; } + // reserve memory for context buffers + { + const auto & hparams = ctx->model.hparams; + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_ctx); + } + + if (params.embedding){ + ctx->embedding.reserve(hparams.n_embd); + } + } + return ctx; } @@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + const char * llama_token_to_str(struct llama_context * ctx, llama_token token) { if (token >= llama_n_vocab(ctx)) { return nullptr; -- cgit v1.2.3