From 70bc0b8b15b98dca23b28f0c8f5e34b27e424cda Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 10 Mar 2023 23:46:39 +0200 Subject: Fix a bug in the rope calculation --- utils.cpp | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) (limited to 'utils.cpp') diff --git a/utils.cpp b/utils.cpp index cd9c001..6a38764 100644 --- a/utils.cpp +++ b/utils.cpp @@ -257,7 +257,7 @@ std::vector llama_tokenize(const gpt_vocab & vocab, const std::st } } - if (l == 0 && t != 13) { + if (l == 0) { break; } @@ -367,6 +367,83 @@ gpt_vocab::id gpt_sample_top_k_top_p( return logits_id[idx].second; } +gpt_vocab::id llama_sample_top_p( + const gpt_vocab & vocab, + const float * logits, + double top_p, + double temp, + std::mt19937 & rng) { + int n_logits = vocab.id_to_token.size(); + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const double scale = 1.0/temp; + for (int i = 0; i < n_logits; ++i) { + logits_id.push_back(std::make_pair(logits[i]*scale, i)); + } + } + + std::sort( + logits_id.begin(), + logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < (int) probs.size(); i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + probs.resize(i + 1); + logits_id.resize(i + 1); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + + //printf("\n"); + //for (int i = 0; i < (int) 10; i++) { + // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + //} + //printf("\n\n"); + //exit(0); + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; +} + + size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) { const int nb = k / qk; const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2); -- cgit v1.2.3