aboutsummaryrefslogtreecommitdiff
path: root/examples/server
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2023-07-25 11:36:17 +0200
committerGitHub <noreply@github.com>2023-07-25 12:36:17 +0300
commitd5512b782b27ff698007dcd175da18959d5f163f (patch)
tree4be7d33b53e51dd99df8fd1e55420ef55447d6d8 /examples/server
parentc798308e3a425eae050a1f249a576fa8c6433327 (diff)
server: add rms_norm_eps parameter (#2380)
Diffstat (limited to 'examples/server')
-rw-r--r--examples/server/server.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4ad0ba9..83c0306 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+ fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_gqa = std::stoi(argv[i]);
}
+ else if (arg == "-eps" || arg == "--rms-norm-eps") {
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ params.rms_norm_eps = std::stof(argv[i]);
+ }
else if (arg == "--rope-freq-base")
{
if (++i >= argc)