aboutsummaryrefslogtreecommitdiff
path: root/llama.h
diff options
context:
space:
mode:
authorEvan Jones <evan.q.jones@gmail.com>2023-07-23 23:58:10 -0400
committerGitHub <noreply@github.com>2023-07-23 23:58:10 -0400
commit84e09a7d8bc4ab6d658b5cd81295ac0add60be78 (patch)
tree934c5480d917325ac8baa29f4edfae99137b56bb /llama.h
parent2f9cf974a066ac0e03fbb235d834b01b0164d743 (diff)
llama : add grammar-based sampling (#1773)
* llama, main : constrain sampling to grammar * allow loading grammar from file * fix whitespace errors * handle & print parser errors * add comments to grammar syntax and allow newlines where unambiguous * add missing include * support alternates in root rule * fix bugs with empty token and EOS * adjust JSON grammar * remove swp file * rewrite ternary expressions Co-authored-by: Henri Vasserman <henv@hot.ee> * use struct for grammar elements and add Unicode support * add unicode escapes * add inverse char ranges * only sample full tokens (no peeking or truncation) * llama : minor style changes blindly applied in online editor - hopefully I didn't break something * update help text * add warning message if EOS is disabled --------- Co-authored-by: Henri Vasserman <henv@hot.ee> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h49
1 files changed, 49 insertions, 0 deletions
diff --git a/llama.h b/llama.h
index 1089909..81a30e1 100644
--- a/llama.h
+++ b/llama.h
@@ -141,6 +141,40 @@ extern "C" {
bool quantize_output_tensor; // quantize output.weight
} llama_model_quantize_params;
+ // grammar types
+ struct llama_grammar;
+
+ // grammar element type
+ enum llama_gretype {
+ // end of rule definition
+ LLAMA_GRETYPE_END = 0,
+
+ // start of alternate definition for rule
+ LLAMA_GRETYPE_ALT = 1,
+
+ // non-terminal element: reference to rule
+ LLAMA_GRETYPE_RULE_REF = 2,
+
+ // terminal element: character (code point)
+ LLAMA_GRETYPE_CHAR = 3,
+
+ // inverse char(s) ([^a], [^a-b] [^abc])
+ LLAMA_GRETYPE_CHAR_NOT = 4,
+
+ // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+ // be an inclusive range ([a-z])
+ LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+ // modifies a preceding LLAMA_GRETYPE_CHAR or
+ // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+ LLAMA_GRETYPE_CHAR_ALT = 6,
+ };
+
+ typedef struct llama_grammar_element {
+ enum llama_gretype type;
+ uint32_t value; // Unicode code point or rule ID
+ } llama_grammar_element;
+
// performance timing information
struct llama_timings {
double t_start_ms;
@@ -333,6 +367,15 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
LLAMA_API llama_token llama_token_nl(); // next-line
+ // Grammar
+ //
+ LLAMA_API struct llama_grammar * llama_grammar_init(
+ const llama_grammar_element ** rules,
+ size_t n_rules,
+ size_t start_rule_index);
+
+ LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -367,6 +410,9 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
+ /// @details Apply constraints from grammar
+ LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -388,6 +434,9 @@ extern "C" {
/// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
+ /// @details Accepts the sampled token into the grammar
+ LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx);