diff options
author | Evan Jones <evan.q.jones@gmail.com> | 2023-07-23 23:58:10 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 23:58:10 -0400 |
commit | 84e09a7d8bc4ab6d658b5cd81295ac0add60be78 (patch) | |
tree | 934c5480d917325ac8baa29f4edfae99137b56bb | |
parent | 2f9cf974a066ac0e03fbb235d834b01b0164d743 (diff) |
llama : add grammar-based sampling (#1773)
* llama, main : constrain sampling to grammar
* allow loading grammar from file
* fix whitespace errors
* handle & print parser errors
* add comments to grammar syntax and allow newlines where unambiguous
* add missing include
* support alternates in root rule
* fix bugs with empty token and EOS
* adjust JSON grammar
* remove swp file
* rewrite ternary expressions
Co-authored-by: Henri Vasserman <henv@hot.ee>
* use struct for grammar elements and add Unicode support
* add unicode escapes
* add inverse char ranges
* only sample full tokens (no peeking or truncation)
* llama : minor style changes
blindly applied in online editor - hopefully I didn't break something
* update help text
* add warning message if EOS is disabled
---------
Co-authored-by: Henri Vasserman <henv@hot.ee>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- | Makefile | 5 | ||||
-rw-r--r-- | examples/CMakeLists.txt | 2 | ||||
-rw-r--r-- | examples/common.cpp | 24 | ||||
-rw-r--r-- | examples/common.h | 1 | ||||
-rw-r--r-- | examples/grammar-parser.cpp | 423 | ||||
-rw-r--r-- | examples/grammar-parser.h | 29 | ||||
-rw-r--r-- | examples/main/main.cpp | 49 | ||||
-rw-r--r-- | grammars/arithmetic.gbnf | 6 | ||||
-rw-r--r-- | grammars/chess.gbnf | 13 | ||||
-rw-r--r-- | grammars/japanese.gbnf | 7 | ||||
-rw-r--r-- | grammars/json.gbnf | 29 | ||||
-rw-r--r-- | grammars/list.gbnf | 4 | ||||
-rw-r--r-- | llama.cpp | 337 | ||||
-rw-r--r-- | llama.h | 49 |
14 files changed, 977 insertions, 1 deletions
@@ -323,6 +323,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) @@ -333,7 +336,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 161960b..4b1f1cf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + grammar-parser.h + grammar-parser.cpp ) if (BUILD_SHARED_LIBS) diff --git a/examples/common.cpp b/examples/common.cpp index 7a1928f..779605f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -438,6 +438,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.grammar = argv[i]; + } else if (arg == "--grammar-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator<char>(file), + std::istreambuf_iterator<char>(), + std::back_inserter(params.grammar) + ); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -514,6 +536,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); + fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); diff --git a/examples/common.h b/examples/common.h index fb8f6d6..7086606 100644 --- a/examples/common.h +++ b/examples/common.h @@ -63,6 +63,7 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with + std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp new file mode 100644 index 0000000..019d5e1 --- /dev/null +++ b/examples/grammar-parser.cpp @@ -0,0 +1,423 @@ +#include "grammar-parser.h" +#include <cstdint> +#include <cwchar> +#include <string> +#include <utility> +#include <stdexcept> +#include <exception> + +namespace grammar_parser { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + std::pair<uint32_t, const char *> decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast<uint8_t>(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector<llama_grammar_element> & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + + bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); + } + + const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; + } + + const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; + } + + std::pair<uint32_t, const char *> parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector<llama_grammar_element> & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = out_elements.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = out_elements.size(); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < out_elements.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + out_elements.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + last_sym_start = out_elements.size(); + // output reference to synthesized rule + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector<llama_grammar_element> sub_rule; + // add preceding symbol to generated rule + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + } + // mark start of alternate def + sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + } + sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); + + // in original rule, replace previous symbol with reference to generated rule + out_elements.resize(last_sym_start); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + + pos = parse_space(pos + 1, is_nested); + } else { + break; + } + } + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + std::vector<llama_grammar_element> rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); + return pos; + } + + const char * parse_rule(parse_state & state, const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(state, pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + + parse_state parse(const char * src) { + try { + parse_state state; + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(state, pos); + } + return state; + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return parse_state(); + } + } + + void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast<char>(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "<U+%04X>", c); + } + } + + bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } + + void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector<llama_grammar_element> & rule, + const std::map<uint32_t, std::string> & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); + } + + void print_grammar(FILE * file, const parse_state & state) { + try { + std::map<uint32_t, std::string> symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, i, state.rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } + } + + std::vector<const llama_grammar_element *> parse_state::c_rules() { + std::vector<const llama_grammar_element *> ret; + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; + } +} diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h new file mode 100644 index 0000000..9037d72 --- /dev/null +++ b/examples/grammar-parser.h @@ -0,0 +1,29 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by llama.h. Supports character +// ranges, grouping, and repetition operators. As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include "llama.h" +#include <vector> +#include <map> +#include <cstdint> +#include <string> + +namespace grammar_parser { + struct parse_state { + std::map<std::string, uint32_t> symbol_ids; + std::vector<std::vector<llama_grammar_element>> rules; + + std::vector<const llama_grammar_element *> c_rules(); + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3bd8ba2..16ddc22 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #include <cassert> #include <cinttypes> @@ -337,6 +338,31 @@ int main(int argc, char ** argv) { fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); + grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar = NULL; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; + } + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + fprintf(stderr, + "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + + std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + // TODO: replace with ring-buffer std::vector<llama_token> last_n_tokens(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -570,6 +596,10 @@ int main(int argc, char ** argv) { logits[llama_token_nl()] = nl_logit; } + if (grammar != NULL) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); @@ -595,6 +625,10 @@ int main(int argc, char ** argv) { } // printf("`%d`", candidates_p.size); + if (grammar != NULL) { + llama_grammar_accept_token(ctx, grammar, id); + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -725,6 +759,18 @@ int main(int argc, char ** argv) { } if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (grammar != NULL) { + llama_grammar_free(grammar); + + std::vector<const llama_grammar_element *> grammar_rules( + parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); + } + } is_interacting = false; } } @@ -756,6 +802,9 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + if (grammar != NULL) { + llama_grammar_free(grammar); + } llama_backend_free(); return 0; diff --git a/grammars/arithmetic.gbnf b/grammars/arithmetic.gbnf new file mode 100644 index 0000000..3aa95a9 --- /dev/null +++ b/grammars/arithmetic.gbnf @@ -0,0 +1,6 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 0000000..ef0fc1b --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,13 @@ +# Specifies chess moves as a list in algebraic notation, using PGN conventions + +# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern +root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ +move ::= (pawn | nonpawn | castle) [+#]? + +# piece type, optional file/rank, optional capture, dest file & rank +nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] + +# optional file & capture, dest file & rank, optional promotion +pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? + +castle ::= "O-O" "-O"? diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 0000000..43f25ab --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [ぁ-ゟ] +katakana ::= [ァ-ヿ] +punctuation ::= [、-〾] +cjk ::= [一-鿿] diff --git a/grammars/json.gbnf b/grammars/json.gbnf new file mode 100644 index 0000000..40fa2b6 --- /dev/null +++ b/grammars/json.gbnf @@ -0,0 +1,29 @@ +# Grammar for subset of JSON - doesn't support full string or number syntax + +root ::= object +value ::= object | array | string | number | boolean | "null" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +# Only plain integers currently +number ::= "-"? [0-9]+ ws +boolean ::= ("true" | "false") ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 0000000..51e6c9c --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" @@ -1966,6 +1966,279 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co } // +// grammar - internal +// + +struct llama_grammar { + const std::vector<std::vector<llama_grammar_element>> rules; + std::vector<std::vector<const llama_grammar_element *>> stacks; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; +}; + +// NOTE: assumes valid utf8 (but checks for overrun) +// adds a terminating 0 for use as pointer +std::vector<uint32_t> decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + const char * pos = src; + std::vector<uint32_t> code_points; + while (*pos != 0) { + uint8_t first_byte = static_cast<uint8_t>(*pos); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = pos + len; // may overrun! + ++pos; + for ( ; pos < end && *pos != 0; ++pos) { + value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); + } + code_points.push_back(value); + } + code_points.push_back(0); + return code_points; +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; + case LLAMA_GRETYPE_ALT: return true; + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector<std::vector<llama_grammar_element>> & rules, + const std::vector<const llama_grammar_element *> & stack, + std::vector<std::vector<const llama_grammar_element *>> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast<size_t>(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + LLAMA_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept( + const std::vector<std::vector<llama_grammar_element>> & rules, + const std::vector<std::vector<const llama_grammar_element *>> & stacks, + const uint32_t chr) { + + std::vector<std::vector<const llama_grammar_element *>> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates( + const std::vector<std::vector<llama_grammar_element>> & rules, + const std::vector<std::vector<const llama_grammar_element *>> & stacks, + const std::vector<llama_grammar_candidate> & candidates); + +static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( + const std::vector<std::vector<llama_grammar_element>> & rules, + const std::vector<const llama_grammar_element *> & stack, + const std::vector<llama_grammar_candidate> & candidates) { + + std::vector<llama_grammar_candidate> rejects; + + if (stack.empty()) { + // accept nothing; EOS is handled elsewhere + rejects.insert(rejects.end(), candidates.begin(), candidates.end()); + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector<llama_grammar_candidate> next_candidates; + for (auto tok : candidates) { + if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { + if (tok.code_points[1] != 0) { + next_candidates.push_back({ tok.index, tok.code_points + 1 }); + } + } else { + rejects.push_back(tok); + } + } + + auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector<std::vector<const llama_grammar_element *>> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1 }); + } + + return rejects; +} + +static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates( + const std::vector<std::vector<llama_grammar_element>> & rules, + const std::vector<std::vector<const llama_grammar_element *>> & stacks, + const std::vector<llama_grammar_candidate> & candidates) { + LLAMA_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector<llama_grammar_candidate>(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + std::vector<std::vector<llama_grammar_element>> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector<std::vector<const llama_grammar_element *>> stacks; + pos = rules[start_rule_index]; + do { + std::vector<const llama_grammar_element *> stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + +// // sampling // @@ -2250,6 +2523,47 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = llama_token_eos(); + + std::vector<std::vector<uint32_t>> candidates_decoded; + std::vector<llama_grammar_candidate> candidates_grammar; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (*str == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(str)); + candidates_grammar.push_back({ i, candidates_decoded.back().data() }); + } + } + + const auto rejects = + llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -2425,6 +2739,29 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos()) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + LLAMA_ASSERT(false); + } + + const char * str = llama_token_to_str(ctx, token); + // Note terminating 0 in decoded string + auto code_points = decode_utf8(str); + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + } + LLAMA_ASSERT(!grammar->stacks.empty()); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + // // quantization // @@ -141,6 +141,40 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + // performance timing information struct llama_timings { double t_start_ms; @@ -333,6 +367,15 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + // Grammar + // + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // Sampling functions /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -367,6 +410,9 @@ extern "C" { LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -388,6 +434,9 @@ extern "C" { /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API void llama_print_timings(struct llama_context * ctx); |