aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorklosax <131523366+klosax@users.noreply.github.com>2023-08-07 19:07:19 +0200
committerGitHub <noreply@github.com>2023-08-07 19:07:19 +0200
commitf3c3b4b1672d860800639c87d3b5d17564692469 (patch)
treeebdc8a40a9e374eb4713da9c6233c8c499cd768a
parent93356bdb7a324a8f6570f99d02af392cd4c45796 (diff)
Add --rope-scale parameter (#2544)
* common.cpp : Add --rope-scale parameter * README.md : Add info about using linear rope scaling
-rw-r--r--examples/common.cpp11
-rw-r--r--examples/main/README.md6
2 files changed, 15 insertions, 2 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 21f4a03..4d3ba9b 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
+ } else if (arg == "--rope-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = 1.0f/std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
@@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
- fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
- fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+ fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
+ fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
+ fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
diff --git a/examples/main/README.md b/examples/main/README.md
index 014112e..55c1609 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -140,6 +140,12 @@ The `--ctx-size` option allows you to set the size of the prompt context used by
- `-c N, --ctx-size N`: Set the size of the prompt context (default: 512). The LLaMA models were built with a context of 2048, which will yield the best results on longer input/inference. However, increasing the context size beyond 2048 may lead to unpredictable results.
+### Extended Context Size
+
+Some fine-tuned models have extened the context length by scaling RoPE. For example, if the original pretrained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
+
+- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
+
### Keep Prompt
The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained.