author     Jed Fox <git@jedfox.com>      2023-05-06 17:01:47 -0400
committer  GitHub <noreply@github.com>   2023-05-06 17:01:47 -0400
commit     3924088512d9e12e90ed6dbf28a6c5712481d33e (patch)
tree       9dacd76924e57e792ab834b47d77b875d1c8ae4f
parent     173d0e6419e8f8f3c1f4f13201b777f4c60629f3 (diff)
Remove default arguments from sampling functions (#1343)
-rw-r--r--  .gitignore               | 1 +
-rw-r--r--  examples/main/main.cpp   | 8 ++++----
-rw-r--r--  llama.cpp                | 2 +-
-rw-r--r--  llama.h                  | 8 ++++----
-rw-r--r--  tests/test-sampling.cpp  | 8 ++++----
5 files changed, 14 insertions, 13 deletions
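
Context for the change: llama.h declares these samplers inside an extern "C" block, and default arguments are a C++-only feature, so the old size_t min_keep = 1 defaults were usable only from C++. With the defaults gone, every call site passes min_keep, the minimum number of candidates a truncation sampler must retain, explicitly. A minimal sketch of the call-site change; the caller below is hypothetical and top_k = 40 is an illustrative value, with min_keep = 1 reproducing the old default:

    #include "llama.h"

    // Sketch only: ctx and candidates_p stand in for a real llama_context
    // and a candidate array prepared by the caller.
    void sample_sketch(struct llama_context * ctx, llama_token_data_array * candidates_p) {
        // Before this patch (C++ only, relying on the default min_keep = 1):
        //     llama_sample_top_k(ctx, candidates_p, 40);
        // After this patch, min_keep must be explicit:
        llama_sample_top_k(ctx, candidates_p, 40, 1);
    }
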
diff --git a/.gitignore b/.gitignore
index e479c61..6f275fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ build-sanitize-addr/
 build-sanitize-thread/
 
 models/*
+*.bin
 
 /main
 /quantize
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 43dca8e..5ac151e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -444,10 +444,10 @@ int main(int argc, char ** argv) {
                     id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
-                    llama_sample_top_k(ctx, &candidates_p, top_k);
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z);
-                    llama_sample_typical(ctx, &candidates_p, typical_p);
-                    llama_sample_top_p(ctx, &candidates_p, top_p);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
                     llama_sample_temperature(ctx, &candidates_p, temp);
                     id = llama_sample_token(ctx, &candidates_p);
                 }
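
The hunk above threads the old default through every truncation sampler, so sampling behavior is unchanged. For reference, the post-patch chain reads as follows when factored into a standalone helper; this is a sketch with illustrative parameter values, not code from this commit:

    #include "llama.h"

    // Hypothetical helper mirroring the updated chain in main.cpp; every
    // min_keep argument is 1, matching the removed defaults.
    static llama_token sample_with_temperature(struct llama_context * ctx,
                                               llama_token_data_array * candidates_p) {
        const int   top_k     = 40;     // illustrative values only
        const float tfs_z     = 1.0f;
        const float typical_p = 1.0f;
        const float top_p     = 0.95f;
        const float temp      = 0.80f;

        llama_sample_top_k(ctx, candidates_p, top_k, 1);        // keep the k best candidates
        llama_sample_tail_free(ctx, candidates_p, tfs_z, 1);    // tail-free truncation
        llama_sample_typical(ctx, candidates_p, typical_p, 1);  // locally typical truncation
        llama_sample_top_p(ctx, candidates_p, top_p, 1);        // nucleus truncation
        llama_sample_temperature(ctx, candidates_p, temp);      // rescale remaining logits
        return llama_sample_token(ctx, candidates_p);           // draw from what remains
    }
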
diff --git a/llama.cpp b/llama.cpp
index 85af4dc..c36c6ce 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1791,7 +1791,7 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
     float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
 
     // Sample the next word X using top-k sampling
-    llama_sample_top_k(nullptr, candidates, int(k));
+    llama_sample_top_k(nullptr, candidates, int(k), 1);
     if (ctx) {
         ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
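
For context on this one-line change: the k forwarded to llama_sample_top_k is mirostat's dynamically computed cutoff, from the first context line of the hunk. My reading of that expression in math notation, where, per the surrounding code, N is the vocabulary size, \hat{s} the estimated Zipf exponent, \hat{\epsilon} = \hat{s} - 1, and \mu the running surprise target:

    k = \left( \frac{\hat{\epsilon} \cdot 2^{\mu}}{1 - N^{-\hat{\epsilon}}} \right)^{1/\hat{s}}

Passing min_keep = 1 matches the removed default, so the sampler behaves exactly as before.
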
diff --git a/llama.h b/llama.h
index e993c46..58c6e06 100644
--- a/llama.h
+++ b/llama.h
@@ -202,16 +202,16 @@ extern "C" {
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
+    LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+    LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
+    LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+    LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
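
The header hunk is the substance of the patch. Inside an extern "C" block these declarations must also parse as C, and C has no default arguments, which is the likely motivation here. A hypothetical fragment illustrating the distinction; it compiles as C or C++, and the commented-out declaration is what a C compiler would reject:

    #include <stddef.h>

    struct llama_context;             // opaque handles, as in llama.h
    struct llama_token_data_array;

    // Valid in both C and C++, matching the post-patch header:
    void llama_sample_top_k(struct llama_context * ctx,
                            struct llama_token_data_array * candidates,
                            int k, size_t min_keep);

    // C++ only; a C compiler rejects the default argument:
    // void llama_sample_top_k(struct llama_context * ctx,
    //                         struct llama_token_data_array * candidates,
    //                         int k, size_t min_keep = 1);
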
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 8ce59af..9174c1e 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -32,7 +32,7 @@ void test_top_k(const std::vector<float> & probs,
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     DUMP(&candidates_p);
-    llama_sample_top_k(nullptr, &candidates_p, k);
+    llama_sample_top_k(nullptr, &candidates_p, k, 1);
     DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
@@ -57,7 +57,7 @@ void test_top_p(const std::vector<float> & probs,
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     llama_sample_softmax(nullptr, &candidates_p);
     DUMP(&candidates_p);
-    llama_sample_top_p(nullptr, &candidates_p, p);
+    llama_sample_top_p(nullptr, &candidates_p, p, 1);
     DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
@@ -80,7 +80,7 @@ void test_tfs(const std::vector<float> & probs,
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_tail_free(nullptr, &candidates_p, z);
+    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
     DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
@@ -103,7 +103,7 @@ void test_typical(const std::vector<float> & probs,
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_typical(nullptr, &candidates_p, p);
+    llama_sample_typical(nullptr, &candidates_p, p, 1);
     DUMP(&candidates_p);
 
     assert(candidates_p.size == expected_probs.size());
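
The test helpers keep their external signatures; only their internal llama_sample_* calls gain the explicit min_keep = 1, so existing invocations are untouched. A hypothetical invocation with illustrative probability vectors, not taken from test-sampling.cpp: with these inputs, top-k with k = 1 should keep only the 0.4 candidate, and top-p with p = 0.7 should keep the 0.4 and 0.3 candidates.

    int main(void) {
        // Assumes the helpers defined above; the values are examples only.
        test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
        test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
        return 0;
    }
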