Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--    examples/server/server.cpp    69
1 file changed, 62 insertions(+), 7 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3bf9859..043e497 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2,8 +2,6 @@
 #include "llama.h"
 #include "build-info.h"
 
-// single thread
-#define CPPHTTPLIB_THREAD_POOL_COUNT 1
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
@@ -12,6 +10,11 @@
 #include "httplib.h"
 #include "json.hpp"
 
+// auto generated files (update with ./deps.sh)
+#include "index.html.hpp"
+#include "index.js.hpp"
+#include "completion.js.hpp"
+
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
 #endif
@@ -21,6 +24,7 @@ using json = nlohmann::json;
 
 struct server_params {
     std::string hostname = "127.0.0.1";
+    std::string public_path = "examples/server/public";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -172,6 +176,12 @@ struct llama_server_context {
     std::string stopping_word;
     int32_t multibyte_pending = 0;
 
+    std::mutex mutex;
+
+    std::unique_lock<std::mutex> lock() {
+        return std::unique_lock<std::mutex>(mutex);
+    }
+
     ~llama_server_context() {
         if (ctx) {
             llama_free(ctx);
@@ -539,6 +549,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stderr, "  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default (default: %d)\n", sparams.port);
+    fprintf(stderr, "  --path PUBLIC_PATH    path from which to serve static files (default %s)\n", sparams.public_path.c_str());
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     fprintf(stderr, "\n");
@@ -565,6 +576,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             sparams.hostname = argv[i];
+        } else if (arg == "--path") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.public_path = argv[i];
         } else if (arg == "--timeout" || arg == "-to") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -839,17 +856,24 @@ static void parse_options_completion(const json & body, llama_server_context & l
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
 
+
 static void log_server_request(const Request & req, const Response & res) {
     LOG_INFO("request", {
         { "remote_addr", req.remote_addr },
         { "remote_port", req.remote_port },
         { "status", res.status },
+        { "method", req.method },
         { "path", req.path },
+        { "params", req.params },
+    });
+
+    LOG_VERBOSE("request", {
         { "request", req.body },
         { "response", res.body },
     });
 }
 
+
 int main(int argc, char ** argv) {
     // own arguments required by this example
     gpt_params params;
@@ -884,16 +908,34 @@ int main(int argc, char ** argv) {
     Server svr;
 
     svr.set_default_headers({
+        { "Server", "llama.cpp" },
         { "Access-Control-Allow-Origin", "*" },
         { "Access-Control-Allow-Headers", "content-type" }
     });
 
+    // this is only called if no index.js is found in the public --path
+    svr.Get("/index.js", [](const Request &, Response & res) {
+        res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript");
+        return false;
+    });
+
+    // this is only called if no index.html is found in the public --path
     svr.Get("/", [](const Request &, Response & res) {
-        res.set_content("<h1>llama.cpp server works</h1>", "text/html");
+        res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html");
+        return false;
+    });
+
+    // this is only called if no index.html is found in the public --path
+    svr.Get("/completion.js", [](const Request &, Response & res) {
+        res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript");
+        return false;
     });
 
     svr.Post("/completion", [&llama](const Request & req, Response & res) {
+        auto lock = llama.lock();
+
         llama.rewind();
+
         llama_reset_timings(llama.ctx);
 
         parse_options_completion(json::parse(req.body), llama);
@@ -1002,6 +1044,8 @@ int main(int argc, char ** argv) {
     });
 
     svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
+        auto lock = llama.lock();
+
         const json body = json::parse(req.body);
         const std::string content = body.value("content", "");
         const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
@@ -1010,6 +1054,8 @@ int main(int argc, char ** argv) {
     });
 
     svr.Post("/embedding", [&llama](const Request & req, Response & res) {
+        auto lock = llama.lock();
+
         const json body = json::parse(req.body);
 
         llama.rewind();
@@ -1040,18 +1086,27 @@ int main(int argc, char ** argv) {
         res.status = 500;
     });
 
+    svr.set_error_handler([](const Request &, Response & res) {
+        res.set_content("File Not Found", "text/plain");
+        res.status = 404;
+    });
+
+    // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
     svr.set_write_timeout(sparams.write_timeout);
 
     if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
-        LOG_ERROR("couldn't bind to server socket", {
-            { "hostname", sparams.hostname },
-            { "port", sparams.port },
-        });
+        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
         return 1;
     }
 
+    // Set the base directory for serving static files
+    svr.set_base_dir(sparams.public_path);
+
+    // to make it ctrl+clickable:
+    fprintf(stdout, "\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
+
     LOG_INFO("HTTP server listening", {
         { "hostname", sparams.hostname },
         { "port", sparams.port },
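
The diff removes the single-thread CPPHTTPLIB_THREAD_POOL_COUNT workaround and instead serializes access to the shared llama state: each handler takes a std::unique_lock from the new llama_server_context::lock() helper for its whole lifetime. Below is a minimal standalone sketch of that per-request locking pattern, assuming only cpp-httplib and the standard library; the shared_state struct and the /bump route are made up for illustration.

// Sketch of the locking idiom from the diff: lock() returns a std::unique_lock
// that keeps the mutex held until the handler finishes.
#include <mutex>
#include <string>

#include "httplib.h"

struct shared_state {
    std::mutex mutex;
    int counter = 0;

    // same shape as llama_server_context::lock() in the diff
    std::unique_lock<std::mutex> lock() {
        return std::unique_lock<std::mutex>(mutex);
    }
};

int main() {
    shared_state state;
    httplib::Server svr;

    svr.Post("/bump", [&state](const httplib::Request &, httplib::Response & res) {
        auto lock = state.lock(); // serializes handlers that touch `state`
        res.set_content(std::to_string(++state.counter), "text/plain");
    });                           // lock released when the lambda returns

    svr.listen("127.0.0.1", 8080);
}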
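The new /, /index.js and /completion.js handlers serve byte arrays from the auto-generated headers and, per the comments in the diff, are only reached when the corresponding file is not found under the --path directory registered via svr.set_base_dir(). A rough sketch of that embedded-fallback idea, assuming a header generated by something like xxd -i (as ./deps.sh would regenerate) that defines index_html and index_html_len; the ./public path is illustrative.

// Sketch: serve a file from disk if present, otherwise fall back to bytes
// compiled into the binary. Assumes an xxd -i style header that defines
// `unsigned char index_html[]` and `unsigned int index_html_len`.
#include "httplib.h"
#include "index.html.hpp" // hypothetical generated header

int main() {
    httplib::Server svr;

    // static files under ./public are tried before the handler below
    svr.set_base_dir("./public");

    // reached only when ./public/index.html does not exist
    svr.Get("/", [](const httplib::Request &, httplib::Response & res) {
        res.set_content(reinterpret_cast<const char *>(index_html), index_html_len, "text/html");
    });

    svr.listen("127.0.0.1", 8080);
}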