diff options
author | Evan Jones <evan.q.jones@gmail.com> | 2023-07-23 23:58:10 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 23:58:10 -0400 |
commit | 84e09a7d8bc4ab6d658b5cd81295ac0add60be78 (patch) | |
tree | 934c5480d917325ac8baa29f4edfae99137b56bb /examples/main/main.cpp | |
parent | 2f9cf974a066ac0e03fbb235d834b01b0164d743 (diff) |
llama : add grammar-based sampling (#1773)
* llama, main : constrain sampling to grammar
* allow loading grammar from file
* fix whitespace errors
* handle & print parser errors
* add comments to grammar syntax and allow newlines where unambiguous
* add missing include
* support alternates in root rule
* fix bugs with empty token and EOS
* adjust JSON grammar
* remove swp file
* rewrite ternary expressions
Co-authored-by: Henri Vasserman <henv@hot.ee>
* use struct for grammar elements and add Unicode support
* add unicode escapes
* add inverse char ranges
* only sample full tokens (no peeking or truncation)
* llama : minor style changes
blindly applied in online editor - hopefully I didn't break something
* update help text
* add warning message if EOS is disabled
---------
Co-authored-by: Henri Vasserman <henv@hot.ee>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3bd8ba2..16ddc22 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #include <cassert> #include <cinttypes> @@ -337,6 +338,31 @@ int main(int argc, char ** argv) { fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); + grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar = NULL; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; + } + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + fprintf(stderr, + "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + + std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + // TODO: replace with ring-buffer std::vector<llama_token> last_n_tokens(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -570,6 +596,10 @@ int main(int argc, char ** argv) { logits[llama_token_nl()] = nl_logit; } + if (grammar != NULL) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); @@ -595,6 +625,10 @@ int main(int argc, char ** argv) { } // printf("`%d`", candidates_p.size); + if (grammar != NULL) { + llama_grammar_accept_token(ctx, grammar, id); + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -725,6 +759,18 @@ int main(int argc, char ** argv) { } if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (grammar != NULL) { + llama_grammar_free(grammar); + + std::vector<const llama_grammar_element *> grammar_rules( + parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); + } + } is_interacting = false; } } @@ -756,6 +802,9 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + if (grammar != NULL) { + llama_grammar_free(grammar); + } llama_backend_free(); return 0; |