aboutsummaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
authorIvan Stepanov <ivanstepanovftw@gmail.com>2023-04-29 08:34:41 +0300
committerGitHub <noreply@github.com>2023-04-29 08:34:41 +0300
commitdd7eff57d8491792010b1002b8de6a4b54912e5c (patch)
treeed7f7c85ef220cafca40976b52bfeac948b3c673 /examples/main/main.cpp
parent7fc50c051ae8a78e9643fdf172d12e20f2dd9b6c (diff)
llama : new sampling algorithms (#1126)
* Sample interface, new samplers. New samplers: - locally typical sampling - tail free sampling - frequency and presence penalty - mirostat Ignore EOS fix: -inf should be used. * mirostat * Added --logit-bias and --no-penalize-nl, removed std::span * Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k) * Save and load example adjust * Tests * Windows build fix * Windows test fix
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp71
1 files changed, 62 insertions, 9 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index fda6557..674920b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
- fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
- params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+ fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+ params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
@@ -387,10 +387,19 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
- const int32_t top_k = params.top_k;
- const float top_p = params.top_p;
const float temp = params.temp;
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+ const float top_p = params.top_p;
+ const float tfs_z = params.tfs_z;
+ const float typical_p = params.typical_p;
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
+ const float alpha_presence = params.presence_penalty;
+ const float alpha_frequency = params.frequency_penalty;
+ const int mirostat = params.mirostat;
+ const float mirostat_tau = params.mirostat_tau;
+ const float mirostat_eta = params.mirostat_eta;
+ const bool penalize_nl = params.penalize_nl;
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +411,58 @@ int main(int argc, char ** argv) {
{
auto logits = llama_get_logits(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
- if (params.ignore_eos) {
- logits[llama_token_eos()] = 0;
+ // Apply params.logit_bias map
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ logits[it->first] += it->second;
}
- id = llama_sample_top_p_top_k(ctx,
- last_n_tokens.data() + n_ctx - params.repeat_last_n,
- params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+ // Apply penalties
+ float nl_logit = logits[llama_token_nl()];
+ auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+ llama_sample_repetition_penalty(ctx, &candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, repeat_penalty);
+ llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, alpha_frequency, alpha_presence);
+ if (!penalize_nl) {
+ logits[llama_token_nl()] = nl_logit;
+ }
+
+ if (temp <= 0) {
+ // Greedy sampling
+ id = llama_sample_token_greedy(ctx, &candidates_p);
+ } else {
+ if (mirostat == 1) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ const int mirostat_m = 100;
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+ } else if (mirostat == 2) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+ } else {
+ // Temperature sampling
+ llama_sample_top_k(ctx, &candidates_p, top_k);
+ llama_sample_tail_free(ctx, &candidates_p, tfs_z);
+ llama_sample_typical(ctx, &candidates_p, typical_p);
+ llama_sample_top_p(ctx, &candidates_p, top_p);
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token(ctx, &candidates_p);
+ }
+ }
+ // printf("`%d`", candidates_p.size);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);