aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llama.cpp7
-rw-r--r--tests/test-sampling.cpp1
2 files changed, 5 insertions, 3 deletions
diff --git a/llama.cpp b/llama.cpp
index a528eef..ac22a48 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
for (size_t i = 0; i < candidates->size; ++i) {
cum_sum += candidates->data[i].p;
- // Check if the running sum is greater than p or if we have kept at least min_keep tokens
- if (cum_sum > p && i >= min_keep) {
- last_idx = i;
+ // Check if the running sum is at least p or if we have kept at least min_keep tokens
+ // we set the last index to i+1 to indicate that the current iterate should be included in the set
+ if (cum_sum >= p && i + 1 >= min_keep) {
+ last_idx = i + 1;
break;
}
}
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 5d693f7..64f9455 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -181,6 +181,7 @@ int main(void) {
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
+ test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);