Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  69
1 file changed, 62 insertions(+), 7 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3bf9859..043e497 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2,8 +2,6 @@
#include "llama.h"
#include "build-info.h"
-// single thread
-#define CPPHTTPLIB_THREAD_POOL_COUNT 1
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
#define CPPHTTPLIB_NO_EXCEPTIONS 1
@@ -12,6 +10,11 @@
#include "httplib.h"
#include "json.hpp"
+// auto generated files (update with ./deps.sh)
+#include "index.html.hpp"
+#include "index.js.hpp"
+#include "completion.js.hpp"
+
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif
@@ -21,6 +24,7 @@ using json = nlohmann::json;
struct server_params {
std::string hostname = "127.0.0.1";
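+ // directory to serve static files from; overridden with --path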
+ std::string public_path = "examples/server/public";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
@@ -172,6 +176,12 @@ struct llama_server_context {
std::string stopping_word;
int32_t multibyte_pending = 0;
+ std::mutex mutex;
+
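+ // handlers that touch the shared context take this lock (auto lock = llama.lock();),
+ // serializing requests now that the single-thread CPPHTTPLIB_THREAD_POOL_COUNT define is gone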
+ std::unique_lock<std::mutex> lock() {
+ return std::unique_lock<std::mutex>(mutex);
+ }
+
~llama_server_context() {
if (ctx) {
llama_free(ctx);
@@ -539,6 +549,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
fprintf(stderr, " --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port);
+ fprintf(stderr, " --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
fprintf(stderr, " --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
fprintf(stderr, "\n");
@@ -565,6 +576,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
break;
}
sparams.hostname = argv[i];
+ } else if (arg == "--path") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.public_path = argv[i];
} else if (arg == "--timeout" || arg == "-to") {
if (++i >= argc) {
invalid_param = true;
@@ -839,17 +856,24 @@ static void parse_options_completion(const json & body, llama_server_context & l
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
+
static void log_server_request(const Request & req, const Response & res) {
LOG_INFO("request", {
{ "remote_addr", req.remote_addr },
{ "remote_port", req.remote_port },
{ "status", res.status },
+ { "method", req.method },
{ "path", req.path },
+ { "params", req.params },
+ });
+
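+ // request and response bodies can be large, so log them only at the verbose level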
+ LOG_VERBOSE("request", {
{ "request", req.body },
{ "response", res.body },
});
}
+
int main(int argc, char ** argv) {
// own arguments required by this example
gpt_params params;
@@ -884,16 +908,34 @@ int main(int argc, char ** argv) {
Server svr;
svr.set_default_headers({
+ { "Server", "llama.cpp" },
{ "Access-Control-Allow-Origin", "*" },
{ "Access-Control-Allow-Headers", "content-type" }
});
+ // this is only called if no index.js is found in the public --path
+ svr.Get("/index.js", [](const Request &, Response & res) {
+ res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript");
+ return false;
+ });
+
+ // this is only called if no index.html is found in the public --path
svr.Get("/", [](const Request &, Response & res) {
- res.set_content("<h1>llama.cpp server works</h1>", "text/html");
+ res.set_content(reinterpret_cast<const char *>(&index_html), index_html_len, "text/html");
+ return false;
+ });
+
+ // this is only called if no completion.js is found in the public --path
+ svr.Get("/completion.js", [](const Request &, Response & res) {
+ res.set_content(reinterpret_cast<const char *>(&completion_js), completion_js_len, "application/javascript");
+ return false;
});
svr.Post("/completion", [&llama](const Request & req, Response & res) {
+ auto lock = llama.lock();
+
llama.rewind();
+
llama_reset_timings(llama.ctx);
parse_options_completion(json::parse(req.body), llama);
@@ -1002,6 +1044,8 @@ int main(int argc, char ** argv) {
});
svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
+ auto lock = llama.lock();
+
const json body = json::parse(req.body);
const std::string content = body.value("content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
@@ -1010,6 +1054,8 @@ int main(int argc, char ** argv) {
});
svr.Post("/embedding", [&llama](const Request & req, Response & res) {
+ auto lock = llama.lock();
+
const json body = json::parse(req.body);
llama.rewind();
@@ -1040,18 +1086,27 @@ int main(int argc, char ** argv) {
res.status = 500;
});
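+ // fallback for requests that nothing else handled: reply with a plain-text 404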
+ svr.set_error_handler([](const Request &, Response & res) {
+ res.set_content("File Not Found", "text/plain");
+ res.status = 404;
+ });
+
// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
- LOG_ERROR("couldn't bind to server socket", {
- { "hostname", sparams.hostname },
- { "port", sparams.port },
- });
+ fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}
+ // Set the base directory for serving static files
+ svr.set_base_dir(sparams.public_path);
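+ // cpp-httplib tries files under this directory before running the GET handlers,
+ // so the embedded index.html/index.js/completion.js above are only fallbacks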
+
+ // to make it ctrl+clickable:
+ fprintf(stdout, "\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
+
LOG_INFO("HTTP server listening", {
{ "hostname", sparams.hostname },
{ "port", sparams.port },