aboutsummaryrefslogtreecommitdiff
path: root/examples/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/common.cpp')
-rw-r--r--examples/common.cpp8
1 files changed, 8 insertions, 0 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 779605f..0e88a12 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gqa = std::stoi(argv[i]);
+ } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+ fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
+ lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;