Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 148
1 file changed, 106 insertions, 42 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2248c24..56ada7e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -4,8 +4,10 @@
 #endif
 
 #include "common.h"
+#include "console.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #include <cassert>
 #include <cinttypes>
@@ -34,9 +36,7 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static console_state con_st;
 static llama_context ** g_ctx;
-
 static bool is_interacting = false;
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -45,7 +45,7 @@ void sigint_handler(int signo) {
         if (!is_interacting) {
             is_interacting=true;
         } else {
-            console_cleanup(con_st);
+            console::cleanup();
             printf("\n");
             llama_print_timings(*g_ctx);
             _exit(130);
@@ -63,10 +63,8 @@ int main(int argc, char ** argv) {
 
     // save choice to use color for later
    // (note for later: this is a slightly awkward choice)
-    con_st.use_color = params.use_color;
-    con_st.multiline_input = params.multiline_input;
-    console_init(con_st);
-    atexit([]() { console_cleanup(con_st); });
+    console::init(params.simple_io, params.use_color);
+    atexit([]() { console::cleanup(); });
 
     if (params.perplexity) {
         printf("\n************\n");
@@ -84,9 +82,17 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
+    if (params.rope_freq_base != 10000.0) {
+        fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
+        // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
@@ -131,17 +137,14 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+    // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
         {
-            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-        }
+            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
 
-        {
-            const std::vector<llama_token> tmp = { 0, };
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
         }
 
         llama_print_timings(ctx);
@@ -319,6 +322,10 @@ int main(int argc, char ** argv) {
             }
         }
 
+        if (params.input_prefix_bos) {
+            fprintf(stderr, "Input prefix with BOS\n");
+        }
+
         if (!params.input_prefix.empty()) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
@@ -332,13 +339,38 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
+    grammar_parser::parse_state parsed_grammar;
+    llama_grammar * grammar = NULL;
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
+        }
+        fprintf(stderr, "%s: grammar:\n", __func__);
+        grammar_parser::print_grammar(stderr, parsed_grammar);
+        fprintf(stderr, "\n");
+
+        {
+            auto it = params.logit_bias.find(llama_token_eos());
+            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                fprintf(stderr,
+                    "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+            }
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar = llama_grammar_init(
+            grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> last_n_tokens(n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
         const char *control_message;
-        if (con_st.multiline_input) {
+        if (params.multiline_input) {
             control_message = " - To return control to LLaMa, end your input with '\\'.\n"
                               " - To return control without starting a new line, end your input with '/'.\n";
         } else {
@@ -366,7 +398,7 @@ int main(int argc, char ** argv) {
     int n_past_guidance = 0;
 
     // the first thing we will do is to output the prompt, so set color accordingly
-    console_set_color(con_st, CONSOLE_COLOR_PROMPT);
+    console::set_display(console::prompt);
 
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
@@ -387,9 +419,9 @@ int main(int argc, char ** argv) {
             // Ensure the input doesn't exceed the context size by truncating embd if necessary.
             if ((int)embd.size() > max_embd_size) {
                 auto skipped_tokens = embd.size() - max_embd_size;
-                console_set_color(con_st, CONSOLE_COLOR_ERROR);
+                console::set_display(console::error);
                 printf("<<input too long: skipped %zu token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
-                console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+                console::set_display(console::reset);
                 fflush(stdout);
                 embd.resize(max_embd_size);
             }
@@ -549,7 +581,7 @@ int main(int argc, char ** argv) {
                 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
                 if (ctx_guidance) {
-                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+                    llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale);
                 }
 
                 // Apply penalties
@@ -565,6 +597,10 @@ int main(int argc, char ** argv) {
                     logits[llama_token_nl()] = nl_logit;
                 }
 
+                if (grammar != NULL) {
+                    llama_sample_grammar(ctx, &candidates_p, grammar);
+                }
+
                 if (temp <= 0) {
                     // Greedy sampling
                     id = llama_sample_token_greedy(ctx, &candidates_p);
@@ -590,20 +626,14 @@ int main(int argc, char ** argv) {
                 }
                 // printf("`%d`", candidates_p.size);
 
+                if (grammar != NULL) {
+                    llama_grammar_accept_token(ctx, grammar, id);
+                }
+
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
             }
 
-            // replace end of text token with newline token when in interactive mode
-            if (id == llama_token_eos() && params.interactive && !params.instruct) {
-                id = llama_token_newline.front();
-                if (params.antiprompt.size() != 0) {
-                    // tokenize and inject first reverse prompt
-                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
-                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
-                }
-            }
-
             // add it to the context
             embd.push_back(id);
 
@@ -634,7 +664,7 @@ int main(int argc, char ** argv) {
         }
         // reset color to default if we there is no pending user input
        if (input_echo && (int)embd_inp.size() == n_consumed) {
-            console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+            console::set_display(console::reset);
         }
 
         // if not currently processing queued inputs;
@@ -660,7 +690,7 @@ int main(int argc, char ** argv) {
                 if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
                     if (params.interactive) {
                         is_interacting = true;
-                        console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+                        console::set_display(console::user_input);
                     }
                     is_antiprompt = true;
                     fflush(stdout);
@@ -669,11 +699,34 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            // deal with end of text token in interactive mode
+            if (last_n_tokens.back() == llama_token_eos()) {
+                if (params.interactive) {
+                    if (params.antiprompt.size() != 0) {
+                        // tokenize and inject first reverse prompt
+                        const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                        is_antiprompt = true;
+                    }
+
+                    is_interacting = true;
+                    printf("\n");
+                    console::set_display(console::user_input);
+                    fflush(stdout);
+                } else if (params.instruct) {
+                    is_interacting = true;
+                }
+            }
+
             if (n_past > 0 && is_interacting) {
                 if (params.instruct) {
                     printf("\n> ");
                 }
 
+                if (params.input_prefix_bos) {
+                    embd_inp.push_back(llama_token_bos());
+                }
+
                 std::string buffer;
                 if (!params.input_prefix.empty()) {
                     buffer += params.input_prefix;
@@ -683,12 +736,12 @@ int main(int argc, char ** argv) {
                 std::string line;
                 bool another_line = true;
                 do {
-                    another_line = console_readline(con_st, line);
+                    another_line = console::readline(line, params.multiline_input);
                     buffer += line;
                 } while (another_line);
 
                 // done taking input, reset color
-                console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+                console::set_display(console::reset);
 
                 // Add tokens to embd only if the input buffer is non-empty
                 // Entering a empty line lets the user pass control back
@@ -720,18 +773,26 @@ int main(int argc, char ** argv) {
             }
 
             if (n_past > 0) {
+                if (is_interacting) {
+                    // reset grammar state if we're restarting generation
+                    if (grammar != NULL) {
+                        llama_grammar_free(grammar);
+
+                        std::vector<const llama_grammar_element *> grammar_rules(
+                            parsed_grammar.c_rules());
+                        grammar = llama_grammar_init(
+                            grammar_rules.data(), grammar_rules.size(),
+                            parsed_grammar.symbol_ids.at("root"));
+                    }
+                }
                 is_interacting = false;
             }
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos()) {
-            if (params.instruct) {
-                is_interacting = true;
-            } else {
-                fprintf(stderr, " [end of text]\n");
-                break;
-            }
+        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+            fprintf(stderr, " [end of text]\n");
+            break;
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
@@ -751,6 +812,9 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free_model(model);
 
+    if (grammar != NULL) {
+        llama_grammar_free(grammar);
+    }
     llama_backend_free();
 
     return 0;
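
For context, a minimal sketch of the grammar-constrained sampling lifecycle this diff wires into main(), assuming the llama.h and grammar-parser.h interfaces visible in the hunks above; the helper name init_grammar is hypothetical, and candidate construction from the model logits is elided:

// Sketch of the llama_grammar lifecycle as used by this change.
// Only calls that appear in the diff above (plus llama_sample_token,
// already used by main.cpp) are assumed to exist.
#include "llama.h"
#include "grammar-parser.h"

#include <vector>

static llama_grammar * init_grammar(const grammar_parser::parse_state & parsed) {
    // c_rules() flattens the parsed rule table into the pointer array the
    // C API expects; "root" names the grammar's start symbol.
    std::vector<const llama_grammar_element *> rules(parsed.c_rules());
    return llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
}

// per sampled token:
//     llama_sample_grammar(ctx, &candidates_p, grammar);  // penalize tokens the grammar forbids
//     id = llama_sample_token(ctx, &candidates_p);        // (or the greedy/mirostat variants)
//     llama_grammar_accept_token(ctx, grammar, id);       // advance the grammar's parse state
//
// on returning control to the user (interactive restart):
//     llama_grammar_free(grammar);
//     grammar = init_grammar(parsed_grammar);             // fresh state for the next reply
//
// at shutdown:
//     llama_grammar_free(grammar);

This mirrors the three integration points in the diff: the sampling hook before token selection, the accept step after it, and the reset-on-restart block, with a final free before llama_backend_free().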