From 84e09a7d8bc4ab6d658b5cd81295ac0add60be78 Mon Sep 17 00:00:00 2001
From: Evan Jones
Date: Sun, 23 Jul 2023 23:58:10 -0400
Subject: llama : add grammar-based sampling (#1773)

* llama, main : constrain sampling to grammar

* allow loading grammar from file

* fix whitespace errors

* handle & print parser errors

* add comments to grammar syntax and allow newlines where unambiguous

* add missing include

* support alternates in root rule

* fix bugs with empty token and EOS

* adjust JSON grammar

* remove swp file

* rewrite ternary expressions

Co-authored-by: Henri Vasserman

* use struct for grammar elements and add Unicode support

* add unicode escapes

* add inverse char ranges

* only sample full tokens (no peeking or truncation)

* llama : minor style changes

blindly applied in online editor - hopefully I didn't break something

* update help text

* add warning message if EOS is disabled

---------

Co-authored-by: Henri Vasserman
Co-authored-by: Georgi Gerganov
---
 examples/main/main.cpp | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

(limited to 'examples/main')

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3bd8ba2..16ddc22 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -6,6 +6,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #include <cassert>
 #include <cinttypes>
@@ -337,6 +338,31 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
+    grammar_parser::parse_state parsed_grammar;
+    llama_grammar * grammar = NULL;
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
+        }
+        fprintf(stderr, "%s: grammar:\n", __func__);
+        grammar_parser::print_grammar(stderr, parsed_grammar);
+        fprintf(stderr, "\n");
+
+        {
+            auto it = params.logit_bias.find(llama_token_eos());
+            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                fprintf(stderr,
+                    "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+            }
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar = llama_grammar_init(
+            grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> last_n_tokens(n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -570,6 +596,10 @@ int main(int argc, char ** argv) {
                     logits[llama_token_nl()] = nl_logit;
                 }
 
+                if (grammar != NULL) {
+                    llama_sample_grammar(ctx, &candidates_p, grammar);
+                }
+
                 if (temp <= 0) {
                     // Greedy sampling
                     id = llama_sample_token_greedy(ctx, &candidates_p);
@@ -595,6 +625,10 @@ int main(int argc, char ** argv) {
                 }
                 // printf("`%d`", candidates_p.size);
 
+                if (grammar != NULL) {
+                    llama_grammar_accept_token(ctx, grammar, id);
+                }
+
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
             }
@@ -725,6 +759,18 @@ int main(int argc, char ** argv) {
         }
 
         if (n_past > 0) {
+            if (is_interacting) {
+                // reset grammar state if we're restarting generation
+                if (grammar != NULL) {
+                    llama_grammar_free(grammar);
+
+                    std::vector<const llama_grammar_element *> grammar_rules(
+                        parsed_grammar.c_rules());
+                    grammar = llama_grammar_init(
+                        grammar_rules.data(), grammar_rules.size(),
+                        parsed_grammar.symbol_ids.at("root"));
+                }
+            }
             is_interacting = false;
         }
     }
@@ -756,6 +802,9 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free_model(model);
 
+    if (grammar != NULL) {
+        llama_grammar_free(grammar);
+    }
     llama_backend_free();
 
     return 0;
--
cgit v1.2.3
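
Note: the grammar plumbing this patch adds to main.cpp follows a simple lifecycle:
parse the GBNF text, build a llama_grammar from the flattened rules, then on every
sampling step filter candidates with llama_sample_grammar and advance the parser
state with llama_grammar_accept_token. The following is a minimal sketch of that
lifecycle in isolation, assuming the llama.cpp C API as of this commit; model and
context setup, the eval/decode call, and rebuilding the candidate array from fresh
logits between steps are all elided, and the greedy sampler is just a placeholder.

    #include "llama.h"
    #include "grammar-parser.h"

    #include <vector>

    // Sketch: build a grammar from GBNF text, or return NULL on parse errors.
    static llama_grammar * grammar_from_text(const char * gbnf) {
        grammar_parser::parse_state parsed = grammar_parser::parse(gbnf);
        if (parsed.rules.empty()) {
            return NULL; // parse errors leave the rule set empty (as main.cpp relies on)
        }
        // c_rules() flattens the parsed rules into the pointer array the C API expects
        std::vector<const llama_grammar_element *> rules(parsed.c_rules());
        return llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
    }

    // Sketch: one constrained sampling step, mirroring the order used in the patch.
    static llama_token sample_one(llama_context * ctx, llama_token_data_array * candidates,
                                  llama_grammar * grammar) {
        // 1) mask out candidate tokens the grammar cannot accept
        llama_sample_grammar(ctx, candidates, grammar);
        // 2) pick a token with any sampler (greedy here for brevity)
        const llama_token id = llama_sample_token_greedy(ctx, candidates);
        // 3) advance the grammar's parse stacks past the chosen token
        llama_grammar_accept_token(ctx, grammar, id);
        return id;
    }

The ordering matters: llama_grammar_accept_token must see the token that was actually
sampled, which is why the patch calls it after the sampler branches, and why an
interactive restart rebuilds the grammar from the retained parse_state (there is no
way to rewind the grammar's state).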
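
The diff shown here is limited to examples/main; the grammar text itself arrives via
params.grammar, which the rest of the PR populates from the command line (the commit
message notes support for loading a grammar from a file). Assuming the --grammar /
--grammar-file flags this PR introduces in the common argument parser, an illustrative
GBNF grammar and invocation could look like the following (the grammar below is a
made-up example, not one shipped by the PR):

    # list.gbnf -- constrain output to a bulleted list
    root ::= item+
    item ::= "- " [^\n]+ "\n"

    ./main -m models/7B/ggml-model.bin --grammar-file list.gbnf -p "Groceries:"

Note the warning added in the patch: a grammar typically terminates by emitting EOS,
so combining a grammar with a logit bias that disables EOS will cause most grammars
to fail.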