diff options
author | beiller <beiller@gmail.com> | 2023-03-12 16:23:15 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-12 22:23:15 +0200 |
commit | 02f0c6fe7f9b7be24c7d339aed016e54a92388ea (patch) | |
tree | 1da03a06b631ee711abeafa9ff410d87bd579f5d /utils.cpp | |
parent | eb062bb012c4e131818dd757a6d3a757fdee3961 (diff) |
Add back top_k (#56)
* Add back top_k
* Update utils.cpp
* Update utils.h
---------
Co-authored-by: Bill Hamilton <bill.hamilton@shopify.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'utils.cpp')
-rw-r--r-- | utils.cpp | 79 |
1 files changed, 4 insertions, 75 deletions
@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) { return true; } -gpt_vocab::id gpt_sample_top_k_top_p( - const gpt_vocab & vocab, - const float * logits, - int top_k, - double top_p, - double temp, - std::mt19937 & rng) { - int n_logits = vocab.id_to_token.size(); - - std::vector<std::pair<double, gpt_vocab::id>> logits_id; - logits_id.reserve(n_logits); - - { - const double scale = 1.0/temp; - for (int i = 0; i < n_logits; ++i) { - logits_id.push_back(std::make_pair(logits[i]*scale, i)); - } - } +void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) { // find the top K tokens std::partial_sort( logits_id.begin(), @@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p( }); logits_id.resize(top_k); - - double maxl = -INFINITY; - for (const auto & kv : logits_id) { - maxl = std::max(maxl, kv.first); - } - - // compute probs for the top K tokens - std::vector<double> probs; - probs.reserve(logits_id.size()); - - double sum = 0.0; - for (const auto & kv : logits_id) { - double p = exp(kv.first - maxl); - probs.push_back(p); - sum += p; - } - - // normalize the probs - for (auto & p : probs) { - p /= sum; - } - - if (top_p < 1.0f) { - double cumsum = 0.0f; - for (int i = 0; i < top_k; i++) { - cumsum += probs[i]; - if (cumsum >= top_p) { - top_k = i + 1; - probs.resize(top_k); - logits_id.resize(top_k); - break; - } - } - - cumsum = 1.0/cumsum; - for (int i = 0; i < (int) probs.size(); i++) { - probs[i] *= cumsum; - } - } - - //printf("\n"); - //for (int i = 0; i < (int) probs.size(); i++) { - // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); - //} - //exit(0); - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); - - return logits_id[idx].second; } -gpt_vocab::id llama_sample_top_p( +gpt_vocab::id llama_sample_top_p_top_k( const gpt_vocab & vocab, const float * logits, std::vector<gpt_vocab::id> & last_n_tokens, double repeat_penalty, + int top_k, double top_p, double temp, std::mt19937 & rng) { @@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p( } } - std::sort( - logits_id.begin(), - logits_id.end(), - [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) { - return a.first > b.first; - }); + sample_top_k(logits_id, top_k); double maxl = -INFINITY; for (const auto & kv : logits_id) { |