about summary refs log tree commit diff
path: root/examples/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/common.cpp')
-rw-r--r--  examples/common.cpp  91
1 files changed, 85 insertions, 6 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 9f10dc2..6c712c7 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -6,6 +6,8 @@
#include <string>
#include <iterator>
#include <algorithm>
+#include <sstream>
+#include <iostream>
#if defined (_WIN32)
#include <fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.temp = std::stof(argv[i]);
+ } else if (arg == "--tfs") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.tfs_z = std::stof(argv[i]);
+ } else if (arg == "--typical") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat_last_n") {
if (++i >= argc) {
invalid_param = true;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.repeat_penalty = std::stof(argv[i]);
+ } else if (arg == "--frequency_penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.frequency_penalty = std::stof(argv[i]);
+ } else if (arg == "--presence_penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.presence_penalty = std::stof(argv[i]);
+ } else if (arg == "--mirostat") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat = std::stoi(argv[i]);
+ } else if (arg == "--mirostat_lr") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_eta = std::stof(argv[i]);
+ } else if (arg == "--mirostat_ent") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") {
if (++i >= argc) {
invalid_param = true;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
- params.ignore_eos = true;
+ params.logit_bias[llama_token_eos()] = -INFINITY;
+ } else if (arg == "--no-penalize-nl") {
+ params.penalize_nl = false;
+ } else if (arg == "-l" || arg == "--logit-bias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::stringstream ss(argv[i]);
+ llama_token key;
+ char sign;
+ std::string value_str;
+ try {
+ if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+ params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ } else {
+ throw std::exception();
+ }
+ } catch (const std::exception &e) {
+ invalid_param = true;
+ break;
+ }
} else if (arg == "--n_parts") {
if (++i >= argc) {
invalid_param = true;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
- fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
- fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", (double)params.top_p);
- fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
- fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
+ fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+ fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+ fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+ fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+ fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+ fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+ fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+ fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+ fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
+ fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+ fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+ fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+ fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+ fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+ fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
+ fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+ fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
- fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
+ fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+ fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");