path: root/examples
Diffstat (limited to 'examples')
-rw-r--r--  examples/common.cpp                  4
-rw-r--r--  examples/main/main.cpp               4
-rw-r--r--  examples/perplexity/perplexity.cpp  70
3 files changed, 51 insertions, 27 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index f1c3bae..6af4402 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -438,8 +438,8 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
// TODO: not great allocating this every time
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to prompt numer of chars, since n_tokens <= n_prompt_chars
- std::vector<llama_token> res(text.size() + (int)add_bos);
- int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
+ std::vector<llama_token> res(text.size() + (int) add_bos);
+ const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);
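
The wrapper above sizes the result to an upper bound (one token per input character, plus the optional BOS), lets the tokenizer fill it, and then shrinks it to the reported count. A minimal sketch of that allocate / fill / resize pattern, using a hypothetical whitespace tokenizer with a llama_tokenize-like signature in place of the real C API (token ids and the BOS id 1 are made up for illustration):

    #include <cassert>
    #include <string>
    #include <vector>

    // Hypothetical tokenizer: writes up to `capacity` ids into `out` and returns
    // how many it produced (one placeholder id per whitespace-separated word).
    static int tokenize_into(const std::string & text, int * out, int capacity, bool add_bos) {
        int n = 0;
        if (add_bos && n < capacity) { out[n] = 1; n++; } // 1 = hypothetical BOS id
        bool in_word = false;
        for (char c : text) {
            const bool is_space = (c == ' ');
            if (!is_space && !in_word && n < capacity) { out[n] = 2 + n; n++; } // placeholder ids
            in_word = !is_space;
        }
        return n;
    }

    // Same pattern as llama_tokenize() in the hunk above: over-allocate, fill, shrink.
    std::vector<int> tokenize(const std::string & text, bool add_bos) {
        std::vector<int> res(text.size() + (int) add_bos); // n_tokens <= n_chars (+ BOS)
        const int n = tokenize_into(text, res.data(), (int) res.size(), add_bos);
        assert(n >= 0);
        res.resize(n);
        return res;
    }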
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 5ac151e..045093c 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -313,7 +313,8 @@ int main(int argc, char ** argv) {
if (n_past + (int) embd.size() > n_ctx) {
const int n_left = n_past - params.n_keep;
- n_past = params.n_keep;
+ // always keep the first token - BOS
+ n_past = std::max(1, params.n_keep);
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -331,7 +332,6 @@ int main(int argc, char ** argv) {
}
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
- // REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
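
The std::max(1, params.n_keep) change above guarantees that the BOS token at position 0 survives a context swap: once the window is full, everything past params.n_keep is discarded and the last n_left/2 tokens of recent history are re-inserted in front of the pending batch. A self-contained sketch of that swap logic on plain integer vectors; the values of n_ctx, n_keep and the token ids are illustrative, not taken from a real run:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_ctx  = 8; // context window size (illustrative)
        const int n_keep = 0; // tokens to always keep from the start of the prompt

        // last_n_tokens holds the most recent n_ctx tokens seen by the model
        std::vector<int> last_n_tokens = {1, 10, 11, 12, 13, 14, 15, 16}; // 1 = BOS
        std::vector<int> embd          = {17};                            // next batch to evaluate
        int n_past = 8;

        if (n_past + (int) embd.size() > n_ctx) {
            const int n_left = n_past - n_keep;

            // always keep the first token - BOS
            n_past = std::max(1, n_keep);

            // insert n_left/2 tokens at the start of embd from last_n_tokens
            embd.insert(embd.begin(),
                        last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(),
                        last_n_tokens.end() - embd.size());
        }

        for (int t : embd) printf("%d ", t); // tokens that will be re-evaluated
        printf("(n_past = %d)\n", n_past);
        return 0;
    }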
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 299a199..9212dee 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -25,46 +25,68 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
+ // BOS tokens will be added for each chunk before eval
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
- int count = 0;
- int seq_count = tokens.size() / params.n_ctx;
- int n_vocab = llama_n_vocab(ctx);
+ int count = 0;
+
+ const int n_chunk = tokens.size() / params.n_ctx;
+ const int n_vocab = llama_n_vocab(ctx);
+ const int n_batch = params.n_batch;
double nll = 0.0;
- fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
+ fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
- for (int i = 0; i < seq_count; ++i) {
- int start = i * params.n_ctx;
- int end = start + params.n_ctx;
+
+ for (int i = 0; i < n_chunk; ++i) {
+ const int start = i * params.n_ctx;
+ const int end = start + params.n_ctx;
+ const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
- int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
- auto start_t = std::chrono::high_resolution_clock::now();
+
+ const auto t_start = std::chrono::high_resolution_clock::now();
+
for (int j = 0; j < num_batches; ++j) {
- int batch_start = start + j * params.n_batch;
- int batch_size = std::min(end - batch_start, params.n_batch);
- if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+ const int batch_start = start + j * n_batch;
+ const int batch_size = std::min(end - batch_start, n_batch);
+
+ // save original token and restore it after eval
+ const auto token_org = tokens[batch_start];
+
+ // add BOS token for the first batch of each chunk
+ if (j == 0) {
+ tokens[batch_start] = llama_token_bos();
+ }
+
+ if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
- auto batch_logits = llama_get_logits(ctx);
+
+ // restore the original token in case it was set to BOS
+ tokens[batch_start] = token_org;
+
+ const auto batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
- auto end_t = std::chrono::high_resolution_clock::now();
+
+ const auto t_end = std::chrono::high_resolution_clock::now();
+
if (i == 0) {
- const float seconds = std::chrono::duration<float>(end_t - start_t).count();
- printf("%.2f seconds per pass - ETA ", seconds);
- int total_seconds = (int)(seconds * seq_count);
+ const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+ fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+ int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
- printf("%d hours ", total_seconds / (60*60));
+ fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
- printf("%d minutes\n", total_seconds / 60);
+ fprintf(stderr, "%d minutes\n", total_seconds / 60);
}
+
// We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
- // calculate the perplexity over the last half the window (so the model always has
+ // calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
//
// We rely on the fact that attention in the forward pass only looks at previous
@@ -76,10 +98,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// process the entire prompt.
for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
- std::vector<float> tok_logits(
- logits.begin() + j * n_vocab,
+ const std::vector<float> tok_logits(
+ logits.begin() + (j + 0) * n_vocab,
logits.begin() + (j + 1) * n_vocab);
- float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+ const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
nll += -std::log(prob);
++count;
}
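
For reference, the loop above accumulates the negative log of the probability the model assigned to each actual next token in the second half of the window; the reported perplexity is exp(nll / count). A minimal, self-contained sketch of that calculation on made-up logits, with a numerically stable softmax helper standing in for the one defined in perplexity.cpp (vocabulary size, logits and target tokens are all illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax over one row of logits.
    static std::vector<float> softmax(const std::vector<float> & logits) {
        float max_logit = logits[0];
        for (float v : logits) max_logit = std::max(max_logit, v);

        std::vector<float> probs(logits.size());
        double sum_exp = 0.0;
        for (size_t i = 0; i < logits.size(); i++) {
            probs[i] = std::exp(logits[i] - max_logit);
            sum_exp += probs[i];
        }
        for (float & p : probs) p /= (float) sum_exp;
        return probs;
    }

    int main() {
        // Toy setup: 4-token "vocabulary", 3 scored positions, made-up logits.
        const std::vector<std::vector<float>> logits = {
            {2.0f, 0.5f, 0.1f, -1.0f},
            {0.2f, 1.7f, 0.3f,  0.0f},
            {0.0f, 0.1f, 2.5f,  0.4f},
        };
        const std::vector<int> next_token = {0, 1, 2}; // actual next token at each position

        double nll = 0.0;
        int count  = 0;
        for (size_t j = 0; j < logits.size(); j++) {
            // probability assigned to the token that actually came next
            const float prob = softmax(logits[j])[next_token[j]];
            nll += -std::log(prob);
            ++count;
        }

        // perplexity = exp(average negative log-likelihood)
        printf("perplexity: %.4lf [%d tokens]\n", std::exp(nll / count), count);
        return 0;
    }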