diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-03-10 23:46:39 +0200 |
---|---|---|
committer | Georgi Gerganov <ggerganov@gmail.com> | 2023-03-10 23:46:57 +0200 |
commit | 70bc0b8b15b98dca23b28f0c8f5e34b27e424cda (patch) | |
tree | b5d02edef24e8218d28a57c7ee4a6e2208759f15 /utils.cpp | |
parent | 18ebda34d67c05f4f5584a9209e7efb949f5fd56 (diff) |
Fix a bug in the rope calculation
Diffstat (limited to 'utils.cpp')
-rw-r--r-- | utils.cpp | 79 |
1 files changed, 78 insertions, 1 deletions
@@ -257,7 +257,7 @@ std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::st } } - if (l == 0 && t != 13) { + if (l == 0) { break; } @@ -367,6 +367,83 @@ gpt_vocab::id gpt_sample_top_k_top_p( return logits_id[idx].second; } +gpt_vocab::id llama_sample_top_p( + const gpt_vocab & vocab, + const float * logits, + double top_p, + double temp, + std::mt19937 & rng) { + int n_logits = vocab.id_to_token.size(); + + std::vector<std::pair<double, gpt_vocab::id>> logits_id; + logits_id.reserve(n_logits); + + { + const double scale = 1.0/temp; + for (int i = 0; i < n_logits; ++i) { + logits_id.push_back(std::make_pair(logits[i]*scale, i)); + } + } + + std::sort( + logits_id.begin(), + logits_id.end(), + [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) { + return a.first > b.first; + }); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector<double> probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < (int) probs.size(); i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + probs.resize(i + 1); + logits_id.resize(i + 1); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + + //printf("\n"); + //for (int i = 0; i < (int) 10; i++) { + // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); + //} + //printf("\n\n"); + //exit(0); + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; +} + + size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) { const int nb = k / qk; const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2); |