aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorIvan Stepanov <ivanstepanovftw@gmail.com>2023-04-07 19:02:12 +0300
committerGitHub <noreply@github.com>2023-04-07 19:02:12 +0300
commit4953e9007f86327aabc8312a7211c18019a3a40e (patch)
tree1419caafb0d1ebbdd3f05d36e461e8ce10a2edc5 /llama.cpp
parentcc9cee8e9e7598bd280295f6264f36d3a9224006 (diff)
llama : always sort logits before nucleus sampling (#812)
* Always sort logits before nucleus sampling * remove second normalization - fix windows build - remove normalization since std::discrete_distribution does not require it
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp17
1 files changed, 3 insertions, 14 deletions
diff --git a/llama.cpp b/llama.cpp
index 581a839..978327a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1236,19 +1236,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
}
}
- if (top_k > 0 && top_k < n_logits) {
- sample_top_k(logits_id, top_k);
- }
-
- float maxl = -std::numeric_limits<float>::infinity();
- for (const auto & kv : logits_id) {
- maxl = Max(maxl, kv.first);
- }
+ sample_top_k(logits_id, top_k > 0 ? Min(top_k, n_logits) : n_logits);
// compute probs for the top k tokens
std::vector<float> probs;
probs.reserve(logits_id.size());
+ float maxl = logits_id[0].first;
double sum = 0.0;
for (const auto & kv : logits_id) {
const float p = expf(kv.first - maxl);
@@ -1271,16 +1265,11 @@ static llama_vocab::id llama_sample_top_p_top_k(
break;
}
}
-
- cumsum = 1.0/cumsum;
- for (int i = 0; i < (int) probs.size(); i++) {
- probs[i] *= cumsum;
- }
}
//printf("\n");
//for (int i = 0; i < (int) 10; i++) {
- // printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+ // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]);
//}
//printf("\n\n");
//exit(0);