 examples/embedding/embedding.cpp   | 15 +++++----------
 examples/perplexity/perplexity.cpp |  8 --------
 llama.cpp                          | 22 ++++++++++++++--------
 llama.h                            |  1 +
 4 files changed, 20 insertions(+), 26 deletions(-)
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 3015293..d397f35 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -1,15 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <string>
-#include <vector>
-
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
             }
         }
 
+        const int n_embd = llama_n_embd(ctx);
         const auto embeddings = llama_get_embeddings(ctx);
 
-        // TODO: print / use the embeddings
+        for (int i = 0; i < n_embd; i++) {
+            printf("%f ", embeddings[i]);
+        }
+        printf("\n");
     }
 
     llama_print_timings(ctx);
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index f0266a0..f617ba3 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1,14 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <vector>
-
 std::vector<double> softmax(const std::vector<float>& logits) {
     std::vector<double> probs(logits.size());
     float max_logit = logits[0];
diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
         double repeat_penalty) {
     auto & rng = lctx.rng;
 
-    const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
+    const int n_logits = lctx.model.hparams.n_vocab;
 
-    int n_logits = vocab.id_to_token.size();
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
@@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (logits[i] < 0.0) {
-                    logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+                if (plogits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                 } else {
-                    logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                 }
             } else {
-                logits_id.push_back(std::make_pair(logits[i]*scale, i));
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
             }
         }
     }
@@ -1677,6 +1677,8 @@ struct llama_context * llama_init_from_file(
         }
 
         const auto & hparams = ctx->model.hparams;
+
+        // resized during inference
         if (params.logits_all) {
             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
         } else {
@@ -1684,7 +1686,7 @@ struct llama_context * llama_init_from_file(
         }
 
         if (params.embedding){
-            ctx->embedding.reserve(hparams.n_embd);
+            ctx->embedding.resize(hparams.n_embd);
         }
 
         ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
@@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
     return ctx->model.hparams.n_ctx;
 }
 
+int llama_n_embd(struct llama_context * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
diff --git a/llama.h b/llama.h
--- a/llama.h
+++ b/llama.h
@@ -109,6 +109,7 @@ extern "C" {
 
     LLAMA_API int llama_n_vocab(struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+    LLAMA_API int llama_n_embd (struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
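With llama_n_embd() exported, callers can query the embedding width instead of hard-coding it, and the reserve() -> resize() swap is what makes reading the buffer safe: reserve() only allocates capacity, so the vector's size stayed zero and llama_get_embeddings() handed out a pointer into storage that eval wrote past the vector's logical end, while resize() gives it a real extent of n_embd elements. A minimal sketch of how the two accessors pair up after this patch (the helper name and surrounding setup are illustrative, not part of the change):

#include "llama.h"
#include <cstdio>

// Sketch only: assumes ctx was created with params.embedding = true
// and that at least one llama_eval() call has already succeeded.
void print_embedding(struct llama_context * ctx) {
    const int     n_embd     = llama_n_embd(ctx);          // new accessor from this patch
    const float * embeddings = llama_get_embeddings(ctx);  // one float per embedding dimension

    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embeddings[i]);
    }
    printf("\n");
}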

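The sampling change is the subtle one: with params.logits_all enabled, lctx.logits holds n_vocab scores for every evaluated token, one row per token, and indexing logits[i] directly read the first token's row. Pointing plogits at logits.data() + logits.size() - n_logits selects the last row, which is the one sampling should use. A standalone sketch of the same pointer arithmetic on a toy buffer (sizes and values made up for illustration):

#include <cstdio>
#include <vector>

int main() {
    const int n_vocab  = 4;  // toy vocabulary size
    const int n_tokens = 3;  // tokens evaluated with logits_all

    // Flat buffer: n_tokens rows of n_vocab logits, last token's row at the end.
    std::vector<float> logits(n_tokens * n_vocab);
    for (size_t i = 0; i < logits.size(); i++) {
        logits[i] = (float) i;
    }

    // Same arithmetic as the patch: step back one row from the end.
    const float * plogits = logits.data() + logits.size() - n_vocab;

    for (int i = 0; i < n_vocab; i++) {
        printf("%f ", plogits[i]);  // prints the last row: 8 9 10 11
    }
    printf("\n");
    return 0;
}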