Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp  148
1 file changed, 106 insertions(+), 42 deletions(-)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2248c24..56ada7e 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -4,8 +4,10 @@
#endif
#include "common.h"
+#include "console.h"
#include "llama.h"
#include "build-info.h"
+#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
@@ -34,9 +36,7 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static console_state con_st;
static llama_context ** g_ctx;
-
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -45,7 +45,7 @@ void sigint_handler(int signo) {
if (!is_interacting) {
is_interacting=true;
} else {
- console_cleanup(con_st);
+ console::cleanup();
printf("\n");
llama_print_timings(*g_ctx);
_exit(130);
@@ -63,10 +63,8 @@ int main(int argc, char ** argv) {
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
- con_st.use_color = params.use_color;
- con_st.multiline_input = params.multiline_input;
- console_init(con_st);
- atexit([]() { console_cleanup(con_st); });
+ console::init(params.simple_io, params.use_color);
+ atexit([]() { console::cleanup(); });
if (params.perplexity) {
printf("\n************\n");
@@ -84,9 +82,17 @@ int main(int argc, char ** argv) {
return 0;
}
+ if (params.rope_freq_base != 10000.0) {
+ fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+ }
+
+ if (params.rope_freq_scale != 1.0) {
+ fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+ }
+
if (params.n_ctx > 2048) {
- fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
- "expect poor results\n", __func__, params.n_ctx);
+ // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+ fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
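
The two new warnings cover `rope_freq_base` and `rope_freq_scale`, which feed the rotary position embedding. As background (not part of this patch; the actual kernel lives in ggml), a sketch of how such parameters typically enter the angle computation: the base replaces the constant 10000 in the per-dimension frequency, and the scale linearly compresses positions so a model trained on 2048 tokens can address a longer window:

    #include <cmath>
    #include <cstdio>

    // sketch of the per-dimension RoPE rotation angle
    float rope_angle(int pos, int dim_pair, int n_dims,
                     float freq_base /*=10000.0f*/, float freq_scale /*=1.0f*/) {
        // frequency for this pair of dimensions: base^(-2i/d)
        const float inv_freq = std::pow(freq_base, -2.0f * dim_pair / n_dims);
        // linear scaling compresses the effective position
        return freq_scale * pos * inv_freq;
    }

    int main() {
        // e.g. a scale of 0.5 halves every position, stretching a
        // 2048-token model across 4096 positions
        printf("%f\n", rope_angle(4096, 0, 128, 10000.0f, 0.5f));
    }
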
@@ -131,17 +137,14 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
- // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+ // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
// uncomment the "used_mem" line in llama.cpp to see the results
if (params.mem_test) {
{
- const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
- llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
- }
+ fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
- {
- const std::vector<llama_token> tmp = { 0, };
- llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+ const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+ llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
}
llama_print_timings(ctx);
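
The rewritten test now evaluates one full batch at the final position `n_ctx`, exercising worst-case memory in a single `llama_eval` call instead of two. For rough intuition about why the cost scales with `n_ctx` (an approximation, not the number the `used_mem` line prints), the KV cache alone grows linearly with context:

    #include <cstdio>
    #include <cstddef>

    // rough size of the KV cache alone: one key and one value vector of
    // n_embd elements per layer per position, stored as f16 (2 bytes)
    size_t kv_cache_bytes(int n_ctx, int n_layer, int n_embd) {
        return 2ull /*K and V*/ * n_ctx * n_layer * n_embd * 2 /*bytes per f16*/;
    }

    int main() {
        // assumed 7B-class shape: 32 layers, n_embd = 4096
        printf("n_ctx = 2048 -> %.0f MiB\n",
               kv_cache_bytes(2048, 32, 4096) / (1024.0 * 1024.0));
    }
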
@@ -319,6 +322,10 @@ int main(int argc, char ** argv) {
}
}
+ if (params.input_prefix_bos) {
+ fprintf(stderr, "Input prefix with BOS\n");
+ }
+
if (!params.input_prefix.empty()) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
@@ -332,13 +339,38 @@ int main(int argc, char ** argv) {
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
+ grammar_parser::parse_state parsed_grammar;
+ llama_grammar * grammar = NULL;
+ if (!params.grammar.empty()) {
+ parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+ // will be empty (default) if there are parse errors
+ if (parsed_grammar.rules.empty()) {
+ return 1;
+ }
+ fprintf(stderr, "%s: grammar:\n", __func__);
+ grammar_parser::print_grammar(stderr, parsed_grammar);
+ fprintf(stderr, "\n");
+
+ {
+ auto it = params.logit_bias.find(llama_token_eos());
+ if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ fprintf(stderr,
+ "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+ }
+ }
+
+ std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+ grammar = llama_grammar_init(
+ grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ }
+
// TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
if (params.interactive) {
const char *control_message;
- if (con_st.multiline_input) {
+ if (params.multiline_input) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
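
The grammar hunk above wires three steps together: parse the GBNF text, detect parse errors by checking `rules.empty()`, and hand the rule table to `llama_grammar_init` starting from the "root" symbol. A minimal end-to-end sketch using only the calls visible in this patch (the yes/no GBNF string is an invented example):

    #include "llama.h"
    #include "grammar-parser.h"
    #include <vector>

    llama_grammar * make_yes_no_grammar() {
        // GBNF: the root rule must produce exactly "yes" or "no"
        auto parsed = grammar_parser::parse("root ::= \"yes\" | \"no\"");
        if (parsed.rules.empty()) {
            return nullptr; // parse error, already reported on stderr
        }
        std::vector<const llama_grammar_element *> rules(parsed.c_rules());
        // note: main.cpp keeps the parsed grammar alive so it can
        // re-initialize from it when generation restarts
        return llama_grammar_init(
            rules.data(), rules.size(), parsed.symbol_ids.at("root"));
    }
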
@@ -366,7 +398,7 @@ int main(int argc, char ** argv) {
int n_past_guidance = 0;
// the first thing we will do is to output the prompt, so set color accordingly
- console_set_color(con_st, CONSOLE_COLOR_PROMPT);
+ console::set_display(console::prompt);
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
@@ -387,9 +419,9 @@ int main(int argc, char ** argv) {
// Ensure the input doesn't exceed the context size by truncating embd if necessary.
if ((int)embd.size() > max_embd_size) {
auto skipped_tokens = embd.size() - max_embd_size;
- console_set_color(con_st, CONSOLE_COLOR_ERROR);
+ console::set_display(console::error);
printf("<<input too long: skipped %zu token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
- console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+ console::set_display(console::reset);
fflush(stdout);
embd.resize(max_embd_size);
}
@@ -549,7 +581,7 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (ctx_guidance) {
- llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+ llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale);
}
// Apply penalties
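
This hunk drops the `cfg_smooth_factor` argument from the guidance call. The step itself is the standard classifier-free-guidance mix; a sketch of the core combination in plain C++ (the library version normalizes both vectors with log-softmax first, which is omitted here):

    #include <vector>
    #include <cstddef>

    // blend main-context logits away from the guidance (negative-prompt)
    // context: scale = 1 keeps the main logits, scale > 1 amplifies the
    // difference from the guidance distribution
    void cfg_mix(std::vector<float> & logits,
                 const std::vector<float> & guidance_logits,
                 float scale) {
        for (size_t i = 0; i < logits.size(); ++i) {
            logits[i] = guidance_logits[i] + scale * (logits[i] - guidance_logits[i]);
        }
    }
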
@@ -565,6 +597,10 @@ int main(int argc, char ** argv) {
logits[llama_token_nl()] = nl_logit;
}
+ if (grammar != NULL) {
+ llama_sample_grammar(ctx, &candidates_p, grammar);
+ }
+
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
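
Conceptually, `llama_sample_grammar` filters the candidate set so only tokens that can extend a valid parse survive, which is why it runs before either the greedy or the temperature-based sampler. A simplified sketch of that masking idea (the real implementation walks the grammar's parse stacks; `token_is_valid` here is a hypothetical stand-in):

    #include <cmath>
    #include <vector>

    struct candidate { int id; float logit; };

    // stand-in for the real check, which tests whether a token can
    // extend any valid derivation of the grammar
    static bool token_is_valid(int id) { return id != 0; /* hypothetical */ }

    // tokens the grammar cannot accept get -inf logits, so whatever
    // sampler runs next can only pick in-grammar tokens
    void apply_grammar_mask(std::vector<candidate> & cands) {
        for (auto & c : cands) {
            if (!token_is_valid(c.id)) {
                c.logit = -INFINITY;
            }
        }
    }
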
@@ -590,20 +626,14 @@ int main(int argc, char ** argv) {
}
// printf("`%d`", candidates_p.size);
+ if (grammar != NULL) {
+ llama_grammar_accept_token(ctx, grammar, id);
+ }
+
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
}
- // replace end of text token with newline token when in interactive mode
- if (id == llama_token_eos() && params.interactive && !params.instruct) {
- id = llama_token_newline.front();
- if (params.antiprompt.size() != 0) {
- // tokenize and inject first reverse prompt
- const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
- embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
- }
- }
-
// add it to the context
embd.push_back(id);
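
The surrounding context still carries the `// TODO: replace with ring-buffer` note: `last_n_tokens.erase(last_n_tokens.begin())` shifts the whole vector on every generated token. A minimal fixed-size ring that would make the push O(1) (a sketch, not code from this patch):

    #include <vector>
    #include <cstddef>

    // fixed-capacity ring: overwrites the oldest element on push
    struct token_ring {
        std::vector<int> buf;
        size_t head = 0; // index of the oldest element

        explicit token_ring(size_t n) : buf(n, 0) {}

        void push(int tok) {
            buf[head] = tok;               // overwrite the oldest slot
            head = (head + 1) % buf.size();
        }

        int back() const {                 // most recently pushed token
            return buf[(head + buf.size() - 1) % buf.size()];
        }
    };
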
@@ -634,7 +664,7 @@ int main(int argc, char ** argv) {
}
// reset color to default if there is no pending user input
if (input_echo && (int)embd_inp.size() == n_consumed) {
- console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+ console::set_display(console::reset);
}
// if not currently processing queued inputs;
@@ -660,7 +690,7 @@ int main(int argc, char ** argv) {
if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
if (params.interactive) {
is_interacting = true;
- console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+ console::set_display(console::user_input);
}
is_antiprompt = true;
fflush(stdout);
@@ -669,11 +699,34 @@ int main(int argc, char ** argv) {
}
}
+ // deal with end of text token in interactive mode
+ if (last_n_tokens.back() == llama_token_eos()) {
+ if (params.interactive) {
+ if (params.antiprompt.size() != 0) {
+ // tokenize and inject first reverse prompt
+ const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+ embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+ is_antiprompt = true;
+ }
+
+ is_interacting = true;
+ printf("\n");
+ console::set_display(console::user_input);
+ fflush(stdout);
+ } else if (params.instruct) {
+ is_interacting = true;
+ }
+ }
+
if (n_past > 0 && is_interacting) {
if (params.instruct) {
printf("\n> ");
}
+ if (params.input_prefix_bos) {
+ embd_inp.push_back(llama_token_bos());
+ }
+
std::string buffer;
if (!params.input_prefix.empty()) {
buffer += params.input_prefix;
@@ -683,12 +736,12 @@ int main(int argc, char ** argv) {
std::string line;
bool another_line = true;
do {
- another_line = console_readline(con_st, line);
+ another_line = console::readline(line, params.multiline_input);
buffer += line;
} while (another_line);
// done taking input, reset color
- console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+ console::set_display(console::reset);
// Add tokens to embd only if the input buffer is non-empty
// Entering an empty line lets the user pass control back
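
With the state struct gone, `console::readline` now takes `multiline_input` as an explicit argument. As the control message earlier in the patch describes, a trailing '\' flips the default continuation behavior and '/' returns control without starting a new line; a simplified model of that contract (hypothetical internals, the real logic lives in console.cpp):

    #include <iostream>
    #include <string>

    // returns true if the caller should read another line, mirroring
    // the do/while loop around console::readline in this patch
    bool readline_sketch(std::string & line, bool multiline_input) {
        std::getline(std::cin, line);
        bool another_line = multiline_input;   // multiline keeps reading by default
        if (!line.empty() && line.back() == '\\') {
            line.pop_back();                   // trailing '\' flips the default
            another_line = !multiline_input;
        } else if (multiline_input && !line.empty() && line.back() == '/') {
            line.pop_back();                   // '/' ends input without a newline
            return false;
        }
        line += '\n';
        return another_line;
    }
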
@@ -720,18 +773,26 @@ int main(int argc, char ** argv) {
}
if (n_past > 0) {
+ if (is_interacting) {
+ // reset grammar state if we're restarting generation
+ if (grammar != NULL) {
+ llama_grammar_free(grammar);
+
+ std::vector<const llama_grammar_element *> grammar_rules(
+ parsed_grammar.c_rules());
+ grammar = llama_grammar_init(
+ grammar_rules.data(), grammar_rules.size(),
+ parsed_grammar.symbol_ids.at("root"));
+ }
+ }
is_interacting = false;
}
}
// end of text token
- if (!embd.empty() && embd.back() == llama_token_eos()) {
- if (params.instruct) {
- is_interacting = true;
- } else {
- fprintf(stderr, " [end of text]\n");
- break;
- }
+ if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+ fprintf(stderr, " [end of text]\n");
+ break;
}
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
@@ -751,6 +812,9 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);
+ if (grammar != NULL) {
+ llama_grammar_free(grammar);
+ }
llama_backend_free();
return 0;