about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  examples/common.cpp        | 6
-rw-r--r--  examples/common.h          | 1
-rw-r--r--  examples/server/server.cpp | 16
3 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/examples/common.cpp b/examples/common.cpp
index 1308f84..478dbaf 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -251,6 +251,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.model = argv[i];
+ } else if (arg == "-a" || arg == "--alias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model_alias = argv[i];
} else if (arg == "--lora") {
if (++i >= argc) {
invalid_param = true;
diff --git a/examples/common.h b/examples/common.h
index 2b66382..fea9aa8 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -45,6 +45,7 @@ struct gpt_params {
float mirostat_eta = 0.10f; // learning rate
std::string model = "models/7B/ggml-model.bin"; // model path
+ std::string model_alias = "unknown"; // model alias
std::string prompt = "";
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3904412..9eacc92 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -400,8 +400,10 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
- fprintf(stderr, " -host ip address to listen (default 127.0.0.1)\n");
- fprintf(stderr, " -port PORT port to listen (default 8080)\n");
+ fprintf(stderr, " -a ALIAS, --alias ALIAS\n");
+ fprintf(stderr, " set an alias for the model, will be added as `model` field in completion response\n");
+ fprintf(stderr, " --host ip address to listen (default 127.0.0.1)\n");
+ fprintf(stderr, " --port PORT port to listen (default 8080)\n");
fprintf(stderr, "\n");
}
@@ -453,6 +455,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
}
params.model = argv[i];
}
+ else if (arg == "-a" || arg == "--alias")
+ {
+ if (++i >= argc)
+ {
+ invalid_param = true;
+ break;
+ }
+ params.model_alias = argv[i];
+ }
else if (arg == "--embedding")
{
params.embedding = true;
@@ -645,6 +656,7 @@ int main(int argc, char **argv)
try
{
json data = {
+ {"model", llama.params.model_alias },
{"content", llama.generated_text },
{"tokens_predicted", llama.num_tokens_predicted}};
return res.set_content(data.dump(), "application/json");