path: root/llama.h
author     Ivan Stepanov <ivanstepanovftw@gmail.com>    2023-04-29 08:34:41 +0300
committer  GitHub <noreply@github.com>                  2023-04-29 08:34:41 +0300
commit     dd7eff57d8491792010b1002b8de6a4b54912e5c (patch)
tree       ed7f7c85ef220cafca40976b52bfeac948b3c673 /llama.h
parent     7fc50c051ae8a78e9643fdf172d12e20f2dd9b6c (diff)
llama : new sampling algorithms (#1126)
* Sample interface, new samplers.

  New samplers:
  - locally typical sampling
  - tail free sampling
  - frequency and presence penalty
  - mirostat

  Ignore EOS fix: -inf should be used.

* mirostat

* Added --logit-bias and --no-penalize-nl, removed std::span

* Use C++11, clarify llama API documentation, rename Mirostat parameters to --mirostat_lr and --mirostat_ent, add temperature sampling for Mirostat, simplify Mirostat sampling API parameters (removed N and *k)

* Save and load example adjust

* Tests

* Windows build fix

* Windows test fix
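The new interface replaces the single llama_sample_top_p_top_k() call with a llama_token_data_array of candidates that the samplers filter and re-weight in place, finishing with one of the llama_sample_token* functions. Below is a minimal usage sketch of the common pipeline (not part of the commit; the helper name sample_next and the parameter values are illustrative), assuming a context whose logits have just been computed:

    // Illustrative sketch, not from this commit: build a candidate array from
    // the logits of the last evaluated token, apply a chain of samplers, and
    // draw the next token. The values 1.10, 40, 0.95, 0.80 are arbitrary.
    #include "llama.h"

    #include <vector>

    llama_token sample_next(struct llama_context * ctx, std::vector<llama_token> & last_tokens) {
        const int n_vocab = llama_n_vocab(ctx);
        float * logits = llama_get_logits(ctx);

        // One candidate per vocabulary entry; p is filled in by the samplers.
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; id++) {
            candidates.push_back(llama_token_data{id, logits[id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // Penalties first, then the truncation samplers, then temperature, then the draw.
        llama_sample_repetition_penalty(ctx, &candidates_p, last_tokens.data(), last_tokens.size(), 1.10f);
        llama_sample_top_k(ctx, &candidates_p, 40);
        llama_sample_top_p(ctx, &candidates_p, 0.95f);
        llama_sample_temperature(ctx, &candidates_p, 0.80f);
        return llama_sample_token(ctx, &candidates_p);
    }

Each truncation sampler only shrinks the candidate array, so the stages compose freely, and the sorted flag lets a later sampler skip re-sorting an array that is already ordered by logit.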
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  64
1 file changed, 52 insertions(+), 12 deletions(-)
diff --git a/llama.h b/llama.h
index 936c521..34a8f5b 100644
--- a/llama.h
+++ b/llama.h
@@ -39,12 +39,16 @@ extern "C" {
     typedef struct llama_token_data {
         llama_token id; // token id
-
+        float logit; // log-odds of the token
         float p; // probability of the token
-        float plog; // log probability of the token
-
     } llama_token_data;
 
+    typedef struct llama_token_data_array {
+        llama_token_data * data;
+        size_t size;
+        bool sorted;
+    } llama_token_data_array;
+
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
@@ -181,16 +185,52 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos();
     LLAMA_API llama_token llama_token_eos();
+    LLAMA_API llama_token llama_token_nl();
+
+    // Sampling functions
+
+    /// @details Repetition penalty described in the CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+    LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
+
+    /// @details Frequency and presence penalties described in the OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+    LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
+
+    /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
+    LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
+
+    /// @details Top-K sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+    LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
+
+    /// @details Nucleus sampling described in the academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+    LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+
+    /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
+    LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
+
+    /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+    LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
+    LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
+
+    /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+    /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
+
+    /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+    /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+    /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+    /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
+    LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
+
+    /// @details Selects the token with the highest probability.
+    LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
 
-    // TODO: improve the last_n_tokens interface ?
-    LLAMA_API llama_token llama_sample_top_p_top_k(
-              struct llama_context * ctx,
-        const llama_token * last_n_tokens_data,
-                        int last_n_tokens_size,
-                        int top_k,
-                      float top_p,
-                      float temp,
-                      float repeat_penalty);
+    /// @details Randomly selects a token from the candidates based on their probabilities.
+    LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
 
     // Performance information
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
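Unlike the stateless samplers above, the two Mirostat functions carry state between calls through mu, which the caller initializes to 2 * tau once and then passes back on every step so the algorithm can keep adjusting it toward the target surprise. A minimal sketch of a Mirostat 2.0 step (not part of the commit; the helper name mirostat_step and the tau/eta values are illustrative, matching the common defaults):

    #include "llama.h"

    // Illustrative sketch, not from this commit. candidates_p is built from the
    // current logits exactly as in the pipeline example above; mu persists
    // across generation steps and is updated in place by the sampler.
    llama_token mirostat_step(struct llama_context * ctx, llama_token_data_array * candidates_p, float * mu) {
        const float tau = 5.0f; // target cross-entropy ("surprise")
        const float eta = 0.1f; // learning rate for the mu updates
        llama_sample_temperature(ctx, candidates_p, 0.8f);
        return llama_sample_token_mirostat_v2(ctx, candidates_p, tau, eta, mu);
    }

    // Caller side: initialize once before generation, then reuse per step:
    //     float mu = 2.0f * 5.0f; // 2 * tau
    //     llama_token id = mirostat_step(ctx, &candidates_p, &mu);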