author    klosax <131523366+klosax@users.noreply.github.com>    2023-07-22 14:21:24 +0200
committer GitHub <noreply@github.com>    2023-07-22 14:21:24 +0200
commit    b5fe67f8c69113bd9354bc1adcfe2df6be323740 (patch)
tree      705c322554824091c35e885214ae19462773eab0 /examples/perplexity
parent    24baa54ac1ff3d4156a2360deb1473af04a9b1a2 (diff)
Perplexity: Compute scores correlated to HellaSwag (#2312)
* Add parameter --perplexity-lines to perplexity.cpp
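
The new mode splits the prompt into lines, evaluates each line as its own sequence (with a BOS token prepended) and prints a per-line perplexity together with a running cumulative perplexity. Only the second half of each line's tokens is scored; the first half serves purely as context, roughly mirroring HellaSwag's context/ending split. In rough notation (n = number of tokens in the line, N = (n-1) - n/2 = number of scored tokens):

    PPL_line = exp( -(1/N) * sum over j = n/2 .. n-2 of log p(token[j+1] | token[1..j]) )

i.e. e raised to the average negative log-likelihood of the scored tokens.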
Diffstat (limited to 'examples/perplexity')
-rw-r--r--  examples/perplexity/perplexity.cpp  78
1 file changed, 77 insertions(+), 1 deletion(-)
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index bfad999..d23b7e7 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -4,6 +4,7 @@
#include <cmath>
#include <ctime>
+#include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
printf("\n");
}
+void perplexity_lines(llama_context * ctx, const gpt_params & params) {
+    // Calculates perplexity over each line of the prompt
+
+    std::vector<std::string> prompt_lines;
+    std::istringstream strstream(params.prompt);
+    std::string line;
+
+    while (std::getline(strstream, line, '\n')) {
+        prompt_lines.push_back(line);
+    }
+
+    const int n_vocab = llama_n_vocab(ctx);
+
+    int counttotal = 0;
+    size_t n_lines = prompt_lines.size();
+
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %zu lines\n", __func__, n_lines);
+
+    printf("\nLine\tPPL line\tPPL cumulative\n");
+
+    for (size_t i = 0; i < n_lines; ++i) {
+
+        // Tokenize and insert BOS at start
+        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+
+        size_t batch_size = batch_embd.size();
+
+        // Stop if line is too long
+        if (batch_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : tokens in line %zu > n_ctx\n", __func__, i);
+            return;
+        }
+
+        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        const auto batch_logits = llama_get_logits(ctx);
+        std::vector<float> logits;
+        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+        double nllline = 0.0;
+        int countline = 0;
+
+        // Perplexity over the second half of the line
+        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[batch_embd[j + 1]];
+
+            nllline += -std::log(prob);
+            ++countline;
+        }
+
+        nll += nllline;
+        counttotal += countline;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%zu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal));
+        fflush(stdout);
+    }
+
+    printf("\n");
+}
+
int main(int argc, char ** argv) {
gpt_params params;
@@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
-    perplexity(ctx, params);
+    if (params.perplexity_lines) {
+        perplexity_lines(ctx, params);
+    } else {
+        perplexity(ctx, params);
+    }
llama_print_timings(ctx);
llama_free(ctx);
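
Below is a minimal, self-contained sketch of the same second-half scoring loop, using toy logits and a local softmax helper instead of the llama.cpp API; all names and values in it are illustrative, not part of the library. The loop starts at n/2, so the first half of the line only conditions the model and is never scored.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of logits.
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) max_logit = std::max(max_logit, v);
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) p /= (float) sum;
    return probs;
}

int main() {
    const int n_vocab = 4;                           // toy vocabulary size
    const std::vector<int> tokens = {0, 2, 1, 3, 2}; // toy tokenized line (BOS first)

    // One row of logits per position, flattened row-major like the buffer
    // copied out of llama_get_logits() in the example; values are made up.
    const std::vector<float> logits = {
        0.1f, 0.4f, 2.0f, 0.2f,
        0.3f, 1.5f, 0.1f, 0.9f,
        0.0f, 0.2f, 0.1f, 2.2f,
        1.1f, 0.3f, 2.0f, 0.4f,
        0.5f, 0.5f, 0.5f, 0.5f,
    };

    const size_t n = tokens.size();
    double nll   = 0.0;
    int    count = 0;

    // Score only the second half of the line: probability of token j+1 given tokens 1..j.
    for (size_t j = n/2; j < n - 1; ++j) {
        const std::vector<float> tok_logits(
            logits.begin() + (j + 0) * n_vocab,
            logits.begin() + (j + 1) * n_vocab);
        const float prob = softmax(tok_logits)[tokens[j + 1]];
        nll += -std::log(prob);
        ++count;
    }

    // Perplexity is e^(average negative log-likelihood) over the scored tokens.
    printf("PPL over second half = %.6f\n", std::exp(nll / count));
    return 0;
}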