-rw-r--r--  examples/embedding/embedding.cpp    | 15
-rw-r--r--  examples/perplexity/perplexity.cpp  |  8
-rw-r--r--  llama.cpp                           | 22
-rw-r--r--  llama.h                             |  1
4 files changed, 20 insertions(+), 26 deletions(-)
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 3015293..d397f35 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -1,15 +1,6 @@
#include "common.h"
#include "llama.h"
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <string>
-#include <vector>
-
int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
@@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
}
}
+ const int n_embd = llama_n_embd(ctx);
const auto embeddings = llama_get_embeddings(ctx);
- // TODO: print / use the embeddings
+ for (int i = 0; i < n_embd; i++) {
+ printf("%f ", embeddings[i]);
+ }
+ printf("\n");
}
llama_print_timings(ctx);
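The hunk above is the whole embedding read-back pattern: llama_n_embd() gives the vector width and llama_get_embeddings() the data pointer. A minimal caller-side sketch, assuming a context created with params.embedding = true and a completed llama_eval(); the helper name and the L2 normalization are ours, not part of this commit:

#include <cmath>
#include <cstdio>
#include "llama.h"

// Print the embedding, L2-normalized as one typically would before
// comparing two embeddings with cosine similarity.
static void print_normalized_embedding(struct llama_context * ctx) {
    const int     n_embd = llama_n_embd(ctx);
    const float * embd   = llama_get_embeddings(ctx);

    double norm = 0.0;
    for (int i = 0; i < n_embd; i++) {
        norm += (double) embd[i] * (double) embd[i];
    }
    norm = std::sqrt(norm);

    for (int i = 0; i < n_embd; i++) {
        printf("%f ", norm > 0.0 ? embd[i] / norm : 0.0);
    }
    printf("\n");
}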
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index f0266a0..f617ba3 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1,14 +1,6 @@
#include "common.h"
#include "llama.h"
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <vector>
-
std::vector<double> softmax(const std::vector<float>& logits) {
std::vector<double> probs(logits.size());
float max_logit = logits[0];
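The softmax shown in this context block tracks the maximum logit so it can be subtracted before exponentiating; that keeps every exponent at or below zero and avoids overflow without changing the result. The body is truncated in the hunk, so the following is a reconstruction of the standard technique, not the file's exact code:

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: exp(x_i - max) / sum_j exp(x_j - max).
static std::vector<double> softmax_sketch(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    std::vector<double> probs(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp((double)(logits[i] - max_logit));
        sum += probs[i];
    }
    for (double & p : probs) {
        p /= sum; // normalize so the probabilities sum to 1
    }
    return probs;
}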
diff --git a/llama.cpp b/llama.cpp
index 0015ede..2bd5203 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
double repeat_penalty) {
auto & rng = lctx.rng;
-    const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
-    int n_logits = vocab.id_to_token.size();
+    const int n_logits = lctx.model.hparams.n_vocab;
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
std::vector<std::pair<double, llama_vocab::id>> logits_id;
logits_id.reserve(n_logits);
@@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
- if (logits[i] < 0.0) {
- logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+ if (plogits[i] < 0.0) {
+ logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
- logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+ logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
- logits_id.push_back(std::make_pair(logits[i]*scale, i));
+ logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
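The plogits offset introduced above is the substance of this hunk: with logits_all enabled, lctx.logits holds one row of n_vocab floats per evaluated token, so indexing the vector directly would sample from the first token's row. Stepping back n_logits from the end always lands on the most recent row, whatever the row count. A standalone demonstration of the arithmetic, with made-up values:

#include <cassert>
#include <vector>

int main() {
    const int n_vocab = 4;
    // Two rows of logits stored contiguously, one row per evaluated token.
    std::vector<float> logits = {
        0.f, 1.f, 2.f, 3.f,   // row for token 0
        4.f, 5.f, 6.f, 7.f,   // row for token 1 (most recent)
    };
    // The last n_vocab entries are the most recent row, for any row count.
    const float * plogits = logits.data() + logits.size() - n_vocab;
    assert(plogits[0] == 4.f && plogits[3] == 7.f);
    return 0;
}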
@@ -1677,6 +1677,8 @@ struct llama_context * llama_init_from_file(
}
const auto & hparams = ctx->model.hparams;
+
+ // resized during inference
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
} else {
@@ -1684,7 +1686,7 @@ struct llama_context * llama_init_from_file(
}
if (params.embedding){
- ctx->embedding.reserve(hparams.n_embd);
+ ctx->embedding.resize(hparams.n_embd);
}
ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
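The reserve() to resize() change just above is a fix, not style: reserve() allocates capacity but leaves size() at zero, so size()-gated consumers see an empty embedding vector and writes through data() target storage outside the vector's logical extent. resize() makes the n_embd slots real. (The logits buffer keeps reserve() because, per the new comment, it is resized during inference.) A minimal illustration of the distinction:

#include <cstring>
#include <vector>

int main() {
    std::vector<float> v;

    v.reserve(8);  // capacity 8, size() == 0: v still looks empty to callers
    v.resize(8);   // size() == 8: data() now points at 8 valid floats

    std::memset(v.data(), 0, v.size() * sizeof(float));
    return 0;
}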
@@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
return ctx->model.hparams.n_ctx;
}
+int llama_n_embd(struct llama_context * ctx) {
+ return ctx->model.hparams.n_embd;
+}
+
float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}
diff --git a/llama.h b/llama.h
index 827abc1..ebf55f4 100644
--- a/llama.h
+++ b/llama.h
@@ -109,6 +109,7 @@ extern "C" {
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
+ LLAMA_API int llama_n_embd (struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
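With this declaration the header exposes the embedding width alongside the existing n_vocab and n_ctx accessors. A hypothetical caller-side check; the function name is ours, not part of the API:

#include <cstdio>
#include "llama.h"

// Log the model dimensions exposed by the public C API.
static void print_dims(struct llama_context * ctx) {
    printf("n_vocab = %d\n", llama_n_vocab(ctx));
    printf("n_ctx   = %d\n", llama_n_ctx(ctx));
    printf("n_embd  = %d\n", llama_n_embd(ctx)); // new in this commit
}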