path: root/examples/save-load-state
author     Ivan Stepanov <ivanstepanovftw@gmail.com>  2023-04-29 08:34:41 +0300
committer  GitHub <noreply@github.com>  2023-04-29 08:34:41 +0300
commit     dd7eff57d8491792010b1002b8de6a4b54912e5c (patch)
tree       ed7f7c85ef220cafca40976b52bfeac948b3c673 /examples/save-load-state
parent     7fc50c051ae8a78e9643fdf172d12e20f2dd9b6c (diff)
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.

  New samplers:
  - locally typical sampling
  - tail free sampling
  - frequency and presence penalty
  - mirostat

  Ignore EOS fix: -inf should be used.

* mirostat

* Added --logit-bias and --no-penalize-nl, removed std::span

* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)

* Save and load example adjust

* Tests

* Windows build fix

* Windows test fix
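For context, the interface introduced by this commit builds a llama_token_data_array of per-token candidates from the logits and pipes it through individual sampler stages. A rough sketch of how the full chain can be composed is below; the concrete parameter values and the last_tokens history buffer are illustrative assumptions, not taken from this diff:

    // Sketch of the new sampling chain, assuming an initialized
    // llama_context * ctx and a std::vector<llama_token> last_tokens
    // holding the recent output history.
    const int     n_vocab = llama_n_vocab(ctx);
    const float * logits  = llama_get_logits(ctx);

    // Build one candidate entry per vocabulary token from the raw logits.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; id++) {
        candidates.emplace_back(llama_token_data{id, logits[id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Penalties operate on the recent token history.
    llama_sample_repetition_penalty(ctx, &candidates_p,
        last_tokens.data(), last_tokens.size(), 1.10f);
    llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,   // new in this commit
        last_tokens.data(), last_tokens.size(), 0.0f, 0.0f);

    // Truncation samplers; the final size_t argument is min_keep.
    llama_sample_top_k    (ctx, &candidates_p, 40,    1);
    llama_sample_tail_free(ctx, &candidates_p, 1.0f,  1);               // new in this commit
    llama_sample_typical  (ctx, &candidates_p, 1.0f,  1);               // new in this commit
    llama_sample_top_p    (ctx, &candidates_p, 0.95f, 1);
    llama_sample_temperature(ctx, &candidates_p, 0.80f);

    llama_token next = llama_sample_token(ctx, &candidates_p);

Mirostat is an alternative to the truncation stages: llama_sample_token_mirostat_v2 picks the token itself, adapting its cutoff through a caller-maintained mu state rather than fixed top-k/top-p thresholds.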
Diffstat (limited to 'examples/save-load-state')
-rw-r--r--  examples/save-load-state/save-load-state.cpp  34
1 file changed, 18 insertions(+), 16 deletions(-)
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 39aa7f8..07dfa2c 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
     // first run
     printf("\n%s", params.prompt.c_str());
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto next_token = llama_sample_top_p_top_k(
-            ctx2,
-            &last_n_tokens_data.back() - params.repeat_last_n,
-            params.repeat_last_n,
-            40,
-            1.0,
-            1.0,
-            1.1);
+        auto logits = llama_get_logits(ctx2);
+        auto n_vocab = llama_n_vocab(ctx2);
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+        auto next_token = llama_sample_token(ctx2, &candidates_p);
         auto next_token_str = llama_token_to_str(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
         printf("%s", next_token_str);
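Note that this example now samples directly from the full, unmodified distribution: the combined top-k/top-p/repeat-penalty call is replaced by a bare llama_sample_token. If one wanted to keep the example's previous behavior (top_k = 40, top_p = 1.0, temp = 1.0, repeat_penalty = 1.1) under the new interface, the sampling call could look roughly like this (a sketch, not part of this diff):

    // Hypothetical drop-in for the removed llama_sample_top_p_top_k call,
    // placed after candidates_p has been built as in the diff above.
    llama_sample_repetition_penalty(ctx, &candidates_p,
        last_n_tokens_data.data() + last_n_tokens_data.size() - params.repeat_last_n,
        params.repeat_last_n, 1.1f);
    llama_sample_top_k(ctx, &candidates_p, 40, 1);
    llama_sample_top_p(ctx, &candidates_p, 1.0f, 1);
    llama_sample_temperature(ctx, &candidates_p, 1.0f);
    auto next_token = llama_sample_token(ctx, &candidates_p);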