diff options
Diffstat (limited to 'utils.h')
-rw-r--r-- | utils.h | 19 |
1 files changed, 6 insertions, 13 deletions
@@ -19,7 +19,7 @@ struct gpt_params { int32_t repeat_last_n = 64; // last n tokens to penalize // sampling parameters - int32_t top_k = 40; // unused + int32_t top_k = 40; float top_p = 0.95f; float temp = 0.80f; float repeat_penalty = 1.30f; @@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab); // - consider only the top K tokens // - from them, consider only the top tokens with cumulative probability > P // -// TODO: not sure if this implementation is correct -// TODO: temperature is not implemented -// -gpt_vocab::id gpt_sample_top_k_top_p( - const gpt_vocab & vocab, - const float * logits, - int top_k, - double top_p, - double temp, - std::mt19937 & rng); - -gpt_vocab::id llama_sample_top_p( +gpt_vocab::id llama_sample_top_p_top_k( const gpt_vocab & vocab, const float * logits, std::vector<gpt_vocab::id> & last_n_tokens, double repeat_penalty, + int top_k, double top_p, double temp, std::mt19937 & rng); +// filer to top K tokens from list of logits +void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k); + // // Quantization // |