diff options
author | klosax <131523366+klosax@users.noreply.github.com> | 2023-07-22 14:21:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 14:21:24 +0200 |
commit | b5fe67f8c69113bd9354bc1adcfe2df6be323740 (patch) | |
tree | 705c322554824091c35e885214ae19462773eab0 /examples | |
parent | 24baa54ac1ff3d4156a2360deb1473af04a9b1a2 (diff) |
Perplexity: Compute scores correlated to HellaSwag (#2312)
* Add parameter --perplexity-lines to perplexity.cpp
Diffstat (limited to 'examples')
-rw-r--r-- | examples/common.cpp | 5 | ||||
-rw-r--r-- | examples/common.h | 1 | ||||
-rw-r--r-- | examples/perplexity/perplexity.cpp | 78 |
3 files changed, 82 insertions, 2 deletions
diff --git a/examples/common.cpp b/examples/common.cpp index 0990195..730b28b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -387,6 +387,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; + } else if (arg == "--perplexity-lines") { + params.perplexity_lines = true; } else if (arg == "--ignore-eos") { params.logit_bias[llama_token_eos()] = -INFINITY; } else if (arg == "--no-penalize-nl") { @@ -512,7 +514,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); + fprintf(stderr, " --perplexity compute perplexity over each ctx window of the prompt\n"); + fprintf(stderr, " --perplexity-lines compute perplexity over each line of the prompt\n"); fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); fprintf(stderr, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); if (llama_mlock_supported()) { diff --git a/examples/common.h b/examples/common.h index 69170df..c936de6 100644 --- a/examples/common.h +++ b/examples/common.h @@ -82,6 +82,7 @@ struct gpt_params { bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token bool perplexity = false; // compute perplexity over the prompt + bool perplexity_lines = false; // compute perplexity over each line of the prompt bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool mem_test = false; // compute maximum memory usage diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index bfad999..d23b7e7 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -4,6 +4,7 @@ #include <cmath> #include <ctime> +#include <sstream> #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) { printf("\n"); } +void perplexity_lines(llama_context * ctx, const gpt_params & params) { + // Calculates perplexity over each line of the prompt + + std::vector<std::string> prompt_lines; + std::istringstream strstream(params.prompt); + std::string line; + + while (std::getline(strstream,line,'\n')) { + prompt_lines.push_back(line); + } + + const int n_vocab = llama_n_vocab(ctx); + + int counttotal = 0; + size_t n_lines = prompt_lines.size(); + + double nll = 0.0; + + fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines); + + printf("\nLine\tPPL line\tPPL cumulative\n"); + + for (size_t i = 0; i < n_lines; ++i) { + + // Tokenize and insert BOS at start + std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true); + + size_t batch_size = batch_embd.size(); + + // Stop if line is too long + if( batch_size > (size_t)params.n_ctx ) { + fprintf(stderr, "%s : tokens in line %lu > n_ctxl\n", __func__, i); + return; + } + + if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + + const auto batch_logits = llama_get_logits(ctx); + std::vector<float> logits; + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + + double nllline = 0.0; + int countline = 0; + + // Perplexity over second half of the line + for (size_t j = batch_size/2; j < batch_size - 1; ++j) { + // Calculate probability of next token, given the previous ones. + const std::vector<float> tok_logits( + logits.begin() + (j + 0) * n_vocab, + logits.begin() + (j + 1) * n_vocab); + + const float prob = softmax(tok_logits)[batch_embd[ j + 1]]; + + nllline += -std::log(prob); + ++countline; + } + + nll += nllline; + counttotal += countline; + + // perplexity is e^(average negative log-likelihood) + printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal) ); + fflush(stdout); + } + + printf("\n"); +} + int main(int argc, char ** argv) { gpt_params params; @@ -168,7 +240,11 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - perplexity(ctx, params); + if (params.perplexity_lines) { + perplexity_lines(ctx, params); + } else { + perplexity(ctx, params); + } llama_print_timings(ctx); llama_free(ctx); |