diff options
author | Bach Le <bach@bullno1.com> | 2023-07-12 00:18:43 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-11 19:18:43 +0300 |
commit | c9c74b4e3f9dcfab8b0032749ff8a579ab4e4d8d (patch) | |
tree | 651d6915218efa83cad8745310f7d1114ca21e2a /llama.h | |
parent | 3ec7e596b2ba3f43c22f441254ca2bcfa91102ba (diff) |
llama : add classifier-free guidance (#2135)
* Initial implementation
* Remove debug print
* Restore signature of llama_init_from_gpt_params
* Free guidance context
* Make freeing of guidance_ctx conditional
* Make Classifier-Free Guidance a sampling function
* Correct typo. CFG already means context-free grammar.
* Record sampling time in llama_sample_classifier_free_guidance
* Shift all values by the max value before applying logsoftmax
* Fix styling based on review
Diffstat (limited to 'llama.h')
-rw-r--r-- | llama.h | 12 |
1 files changed, 12 insertions, 0 deletions
@@ -309,6 +309,18 @@ extern "C" { /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. + LLAMA_API void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale, + float smooth_factor); + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); |