diff options
author | WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> | 2023-07-03 05:38:44 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-03 00:38:44 +0300 |
commit | d7d2e6a0f0c74f7a570dae384dfff371ac744d2a (patch) | |
tree | ad82a8c9b71b5375936062ca2fc89fc012af10ed /examples/server | |
parent | 46088f72318981341a2d646f12f6eee6aec06d65 (diff) |
server: add option to output probabilities for completion (#1962)
* server: add option to output probabilities for completion
* server: fix issue when handling probability output for incomplete tokens for multibyte character generation
* server: fix llama_sample_top_k order
* examples/common.h: put all bool variables in gpt_params together
Diffstat (limited to 'examples/server')
-rw-r--r-- | examples/server/server.cpp | 150 |
1 files changed, 120 insertions, 30 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 998d55e..e4ddbe9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -26,6 +26,17 @@ struct server_params { int32_t write_timeout = 600; }; +// completion token output with probabilities +struct completion_token_output { + struct token_prob { + llama_token tok; + float prob; + }; + + std::vector<token_prob> probs; + llama_token tok; +}; + static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) { size_t i; for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} @@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line, fflush(stdout); } +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { + std::string out = token == -1 ? "" : llama_token_to_str(ctx, token); + // if first bit is 1, meaning it's a partial character + if (out.size() > 0 && (out[0] & 0x80) == 0x80) { + std::stringstream ss; + ss<< std::hex << (out[0] & 0xff); + std::string res ( ss.str() ); + out = "byte: \\x" + res; + } + return out; +} + +// convert a vector of completion_token_output to json +static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) { + json out = json::array(); + for (const auto & prob : probs) { + json probs_for_token = json::array(); + for (const auto & p : prob.probs) { + std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); + probs_for_token.push_back(json { + { "tok_str", tok_str }, + { "prob", p.prob }, + }); + } + std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); + out.push_back(json { + {"content", tok_str}, + {"probs", probs_for_token}, + }); + } + return out; +} + static bool server_verbose = false; #if SERVER_VERBOSE != 1 @@ -107,6 +152,7 @@ struct llama_server_context { bool stream = false; bool has_next_token = false; std::string generated_text; + std::vector<completion_token_output> generated_token_probs; size_t num_tokens_predicted = 0; size_t n_past = 0; @@ -142,6 +188,7 @@ struct llama_server_context { num_tokens_predicted = 0; generated_text = ""; generated_text.reserve(params.n_ctx); + generated_token_probs.clear(); truncated = false; stopped_eos = false; stopped_word = false; @@ -221,8 +268,9 @@ struct llama_server_context { llama_set_rng_seed(ctx, params.seed); } - llama_token nextToken() { - llama_token result = -1; + completion_token_output nextToken() { + completion_token_output result; + result.tok = -1; if (embd.size() >= (size_t)params.n_ctx) { // Reset context @@ -261,7 +309,8 @@ struct llama_server_context { if (params.n_predict == 0) { has_next_token = false; - return llama_token_eos(); + result.tok = llama_token_eos(); + return result; } // out of user input, sample next token @@ -278,7 +327,7 @@ struct llama_server_context { const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; - llama_token id = 0; + const int32_t n_probs = params.n_probs; { auto * logits = llama_get_logits(ctx); @@ -312,35 +361,42 @@ struct llama_server_context { if (temp <= 0) { // Greedy sampling - id = llama_sample_token_greedy(ctx, &candidates_p); + result.tok = llama_sample_token_greedy(ctx, &candidates_p); + if (n_probs > 0) { + llama_sample_softmax(ctx, &candidates_p); + } } else { if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); + result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); - llama_sample_typical(ctx, &candidates_p, typical_p, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); + size_t min_keep = std::max(1, n_probs); + llama_sample_top_k(ctx, &candidates_p, top_k, min_keep); + llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); + llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); + llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token(ctx, &candidates_p); + result.tok = llama_sample_token(ctx, &candidates_p); } } + + for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(id); + last_n_tokens.push_back(result.tok); num_tokens_predicted++; } // add it to the context - embd.push_back(id); - result = id; + embd.push_back(result.tok); // decrement remaining sampling budget --n_remain; @@ -382,12 +438,16 @@ struct llama_server_context { return stop_pos; } - std::string doCompletion() { - const llama_token token = nextToken(); + completion_token_output doCompletion() { + const completion_token_output token_with_probs = nextToken(); - const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token); + const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok); generated_text += token_text; + if (params.n_probs > 0) { + generated_token_probs.push_back(token_with_probs); + } + if (multibyte_pending > 0) { multibyte_pending -= token_text.size(); } else if (token_text.size() == 1) { @@ -416,8 +476,8 @@ struct llama_server_context { } LOG_VERBOSE("next token", { - { "token", token }, - { "token_text", llama_token_to_str(ctx, token) }, + { "token", token_with_probs.tok }, + { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) }, { "has_next_token", has_next_token }, { "n_remain", n_remain }, { "num_tokens_predicted", num_tokens_predicted }, @@ -427,7 +487,7 @@ struct llama_server_context { { "stopping_word", stopping_word }, }); - return token_text; + return token_with_probs; } std::vector<float> getEmbedding() { @@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) { { "ignore_eos", ignore_eos }, { "stream", llama.stream }, { "logit_bias", llama.params.logit_bias }, + { "n_probs", llama.params.n_probs }, }; } @@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) { }; } -static json format_final_response(llama_server_context & llama, const std::string & content) { - return json { +static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) { + + json res = json { { "content", content }, { "stop", true }, { "model", llama.params.model_alias }, @@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin { "stopped_limit", llama.stopped_limit }, { "stopping_word", llama.stopping_word }, }; + + if (llama.params.n_probs > 0) { + res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); + } + + return res; } -static json format_partial_response(const std::string & content) { - return json { +static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) { + json res = json { { "content", content }, { "stop", false }, }; + + if (llama.params.n_probs > 0) { + res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); + } + + return res; } static json format_tokenizer_response(const std::vector<llama_token> & tokens) { @@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.seed = body.value("seed", default_params.seed); llama.params.prompt = body.value("prompt", default_params.prompt); + llama.params.n_probs = body.value("n_probs", default_params.n_probs); llama.params.logit_bias.clear(); if (body.value("ignore_eos", false)) { @@ -830,7 +905,8 @@ int main(int argc, char ** argv) { size_t stop_pos = std::string::npos; while (llama.has_next_token) { - const std::string token_text = llama.doCompletion(); + const completion_token_output token_with_probs = llama.doCompletion(); + const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok); stop_pos = llama.findStoppingStrings(llama.generated_text, token_text.size(), STOP_FULL); @@ -844,7 +920,7 @@ int main(int argc, char ** argv) { llama.generated_text.end()); } - const json data = format_final_response(llama, llama.generated_text); + const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs); llama_print_timings(llama.ctx); @@ -853,9 +929,11 @@ int main(int argc, char ** argv) { } else { const auto chunked_content_provider = [&](size_t, DataSink & sink) { size_t sent_count = 0; + size_t sent_token_probs_index = 0; while (llama.has_next_token) { - const std::string token_text = llama.doCompletion(); + const completion_token_output token_with_probs = llama.doCompletion(); + const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok); if (llama.multibyte_pending > 0) { continue; } @@ -878,10 +956,22 @@ int main(int argc, char ** argv) { const std::string to_send = llama.generated_text.substr(pos, stop_pos); sent_count += to_send.size(); + std::vector<completion_token_output> probs_output = {}; + + if (llama.params.n_probs > 0) { + const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); + size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + const json data = llama.has_next_token - ? format_partial_response(to_send) + ? format_partial_response(llama, to_send, probs_output) // Generation is done, send extra information. - : format_final_response(llama, to_send); + : format_final_response(llama, to_send, llama.generated_token_probs); const std::string str = "data: " + |