aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/common.cpp91
-rw-r--r--examples/common.h21
-rw-r--r--examples/main/main.cpp71
-rw-r--r--examples/save-load-state/save-load-state.cpp34
4 files changed, 180 insertions, 37 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 9f10dc2..6c712c7 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -6,6 +6,8 @@
#include <string>
#include <iterator>
#include <algorithm>
+#include <sstream>
+#include <iostream>
#if defined (_WIN32)
#include <fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.temp = std::stof(argv[i]);
+ } else if (arg == "--tfs") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.tfs_z = std::stof(argv[i]);
+ } else if (arg == "--typical") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat_last_n") {
if (++i >= argc) {
invalid_param = true;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.repeat_penalty = std::stof(argv[i]);
+ } else if (arg == "--frequency_penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.frequency_penalty = std::stof(argv[i]);
+ } else if (arg == "--presence_penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.presence_penalty = std::stof(argv[i]);
+ } else if (arg == "--mirostat") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat = std::stoi(argv[i]);
+ } else if (arg == "--mirostat_lr") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_eta = std::stof(argv[i]);
+ } else if (arg == "--mirostat_ent") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") {
if (++i >= argc) {
invalid_param = true;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
- params.ignore_eos = true;
+ params.logit_bias[llama_token_eos()] = -INFINITY;
+ } else if (arg == "--no-penalize-nl") {
+ params.penalize_nl = false;
+ } else if (arg == "-l" || arg == "--logit-bias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::stringstream ss(argv[i]);
+ llama_token key;
+ char sign;
+ std::string value_str;
+ try {
+ if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+ params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ } else {
+ throw std::exception();
+ }
+ } catch (const std::exception &e) {
+ invalid_param = true;
+ break;
+ }
} else if (arg == "--n_parts") {
if (++i >= argc) {
invalid_param = true;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
- fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
- fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
- fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
- fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
+ fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+ fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+ fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+ fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+ fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+ fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+ fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+ fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+ fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
+ fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+ fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+ fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+ fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+ fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+ fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
+ fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+ fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
- fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
+ fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+ fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
diff --git a/examples/common.h b/examples/common.h
index 9d3697d..14e6b1b 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -8,6 +8,7 @@
#include <vector>
#include <random>
#include <thread>
+#include <unordered_map>
//
// CLI argument parsing
@@ -17,17 +18,25 @@ struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 128; // new tokens to predict
- int32_t repeat_last_n = 64; // last n tokens to penalize
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
// sampling parameters
- int32_t top_k = 40;
- float top_p = 0.95f;
- float temp = 0.80f;
- float repeat_penalty = 1.10f;
+ std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+ int32_t top_k = 0; // <= 0 to use vocab size
+ float top_p = 1.0f; // 1.0 = disabled
+ float tfs_z = 1.0f; // 1.0 = disabled
+ float typical_p = 1.0f; // 1.0 = disabled
+ float temp = 1.0f; // 1.0 = disabled
+ float repeat_penalty = 1.0f; // 1.0 = disabled
+ int32_t repeat_last_n = -1; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float frequency_penalty = 0.0f; // 0.0 = disabled
+ float presence_penalty = 0.0f; // 0.0 = disabled
+ int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.0f; // target entropy
+ float mirostat_eta = 0.1f; // learning rate
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
@@ -47,7 +56,7 @@ struct gpt_params {
bool interactive_first = false; // wait for user input immediately
bool instruct = false; // instruction mode (used for Alpaca models)
- bool ignore_eos = false; // do not stop generating after eos
+ bool penalize_nl = true; // consider newlines as a repeatable token
bool perplexity = false; // compute perplexity over the prompt
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index fda6557..674920b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -276,8 +276,8 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}
- fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
- params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+ fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+ params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
@@ -387,10 +387,19 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
- const int32_t top_k = params.top_k;
- const float top_p = params.top_p;
const float temp = params.temp;
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+ const float top_p = params.top_p;
+ const float tfs_z = params.tfs_z;
+ const float typical_p = params.typical_p;
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
+ const float alpha_presence = params.presence_penalty;
+ const float alpha_frequency = params.frequency_penalty;
+ const int mirostat = params.mirostat;
+ const float mirostat_tau = params.mirostat_tau;
+ const float mirostat_eta = params.mirostat_eta;
+ const bool penalize_nl = params.penalize_nl;
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
@@ -402,14 +411,58 @@ int main(int argc, char ** argv) {
{
auto logits = llama_get_logits(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
- if (params.ignore_eos) {
- logits[llama_token_eos()] = 0;
+ // Apply params.logit_bias map
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ logits[it->first] += it->second;
}
- id = llama_sample_top_p_top_k(ctx,
- last_n_tokens.data() + n_ctx - params.repeat_last_n,
- params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+ // Apply penalties
+ float nl_logit = logits[llama_token_nl()];
+ auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+ llama_sample_repetition_penalty(ctx, &candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, repeat_penalty);
+ llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, alpha_frequency, alpha_presence);
+ if (!penalize_nl) {
+ logits[llama_token_nl()] = nl_logit;
+ }
+
+ if (temp <= 0) {
+ // Greedy sampling
+ id = llama_sample_token_greedy(ctx, &candidates_p);
+ } else {
+ if (mirostat == 1) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ const int mirostat_m = 100;
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+ } else if (mirostat == 2) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+ } else {
+ // Temperature sampling
+ llama_sample_top_k(ctx, &candidates_p, top_k);
+ llama_sample_tail_free(ctx, &candidates_p, tfs_z);
+ llama_sample_typical(ctx, &candidates_p, typical_p);
+ llama_sample_top_p(ctx, &candidates_p, top_p);
+ llama_sample_temperature(ctx, &candidates_p, temp);
+ id = llama_sample_token(ctx, &candidates_p);
+ }
+ }
+ // printf("`%d`", candidates_p.size);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 39aa7f8..07dfa2c 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
// first run
printf("\n%s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
- auto next_token = llama_sample_top_p_top_k(
- ctx,
- &last_n_tokens_data.back() - params.repeat_last_n,
- params.repeat_last_n,
- 40,
- 1.0,
- 1.0,
- 1.1);
+ auto logits = llama_get_logits(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx, &candidates_p);
auto next_token_str = llama_token_to_str(ctx, next_token);
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
- auto next_token = llama_sample_top_p_top_k(
- ctx2,
- &last_n_tokens_data.back() - params.repeat_last_n,
- params.repeat_last_n,
- 40,
- 1.0,
- 1.0,
- 1.1);
+ auto logits = llama_get_logits(ctx2);
+ auto n_vocab = llama_n_vocab(ctx2);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx2, &candidates_p);
auto next_token_str = llama_token_to_str(ctx2, next_token);
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str);