Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 44 | +++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 43 insertions(+), 1 deletion(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 12d4e2f..c0984aa 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -254,6 +254,11 @@ struct llama_server_context {
             n_past += n_eval;
         }
 
+        if (params.n_predict == 0) {
+            has_next_token = false;
+            return llama_token_eos();
+        }
+
         // out of user input, sample next token
         const float temp = params.temp;
         const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -419,6 +424,19 @@ struct llama_server_context {
 
         return token_text;
     }
+
+    std::vector<float> getEmbedding() {
+        static const int n_embd = llama_n_embd(ctx);
+        if (!params.embedding) {
+            LOG_WARNING("embedding disabled", {
+                { "params.embedding", params.embedding },
+            });
+            return std::vector<float>(n_embd, 0.0f);
+        }
+        const float * data = llama_get_embeddings(ctx);
+        std::vector<float> embedding(data, data + n_embd);
+        return embedding;
+    }
 };
 
 static void server_print_usage(const char * argv0, const gpt_params & params,
@@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
+    fprintf(stderr, "  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     fprintf(stderr, "\n");
 }
 
@@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             params.use_mlock = true;
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
@@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
     };
 }
 
+static json format_embedding_response(llama_server_context & llama) {
+    return json {
+        { "embedding", llama.getEmbedding() },
+    };
+}
+
 static json format_final_response(llama_server_context & llama, const std::string & content) {
     return json {
         { "content", content },
@@ -881,12 +908,27 @@ int main(int argc, char ** argv) {
 
     svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
         const json body = json::parse(req.body);
-        const std::string content = body["content"].get<std::string>();
+        const std::string content = body.value("content", "");
         const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json");
    });
 
+    svr.Post("/embedding", [&llama](const Request & req, Response & res) {
+        const json body = json::parse(req.body);
+
+        llama.rewind();
+        llama_reset_timings(llama.ctx);
+        llama.params.prompt = body.value("content", "");
+        llama.params.n_predict = 0;
+        llama.loadPrompt();
+        llama.beginCompletion();
+        llama.doCompletion();
+
+        const json data = format_embedding_response(llama);
+        return res.set_content(data.dump(), "application/json");
+    });
+
     svr.set_logger(log_server_request);
 
     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
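
Not part of the patch, but for illustration: once the server is rebuilt and started with the new flag (e.g. ./server -m model.bin --embedding), a client can exercise the new endpoint as sketched below. The sketch reuses the cpp-httplib and nlohmann::json headers the server itself is built on; the host, port, and prompt text are assumptions, not values taken from the commit.

// Minimal client sketch for the new POST /embedding endpoint (hypothetical
// standalone program, not part of this commit). Assumes the server was
// started with --embedding and is listening on localhost:8080.
#include <cstdio>
#include <vector>
#include "httplib.h"   // cpp-httplib, same header the server uses
#include "json.hpp"    // nlohmann::json, same header the server uses

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);

    // The handler reads the "content" field of the JSON body and replies
    // with {"embedding": [ ... n_embd floats ... ]}.
    const json req = { { "content", "Hello, world!" } };
    auto res = cli.Post("/embedding", req.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "request failed\n");
        return 1;
    }

    const json data = json::parse(res->body);
    const auto embedding = data["embedding"].get<std::vector<float>>();
    printf("embedding dimension: %zu\n", embedding.size());
    return 0;
}

Because the handler sets llama.params.n_predict = 0 before calling doCompletion(), the new n_predict == 0 early return fires as soon as the prompt has been evaluated, so the server returns the embedding of the prompt without sampling any tokens.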