diff options
author | Ivan Stepanov <ivanstepanovftw@gmail.com> | 2023-04-29 08:34:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-29 08:34:41 +0300 |
commit | dd7eff57d8491792010b1002b8de6a4b54912e5c (patch) | |
tree | ed7f7c85ef220cafca40976b52bfeac948b3c673 /examples/save-load-state | |
parent | 7fc50c051ae8a78e9643fdf172d12e20f2dd9b6c (diff) |
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.
New samplers:
- locally typical sampling
- tail free sampling
- frequency and presence penalty
- mirostat
Ignore EOS fix: -inf should be used.
* mirostat
* Added --logit-bias and --no-penalize-nl, removed std::span
* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)
* Save and load example adjust
* Tests
* Windows build fix
* Windows test fix
Diffstat (limited to 'examples/save-load-state')
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 39aa7f8..07dfa2c 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -64,14 +64,15 @@ int main(int argc, char ** argv) { // first run printf("\n%s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto next_token = llama_sample_top_p_top_k( - ctx, - &last_n_tokens_data.back() - params.repeat_last_n, - params.repeat_last_n, - 40, - 1.0, - 1.0, - 1.1); + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + std::vector<llama_token_data> candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx, &candidates_p); auto next_token_str = llama_token_to_str(ctx, next_token); last_n_tokens_data.push_back(next_token); printf("%s", next_token_str); @@ -106,14 +107,15 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto next_token = llama_sample_top_p_top_k( - ctx2, - &last_n_tokens_data.back() - params.repeat_last_n, - params.repeat_last_n, - 40, - 1.0, - 1.0, - 1.1); + auto logits = llama_get_logits(ctx2); + auto n_vocab = llama_n_vocab(ctx2); + std::vector<llama_token_data> candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx2, &candidates_p); auto next_token_str = llama_token_to_str(ctx2, next_token); last_n_tokens_data.push_back(next_token); printf("%s", next_token_str); |