aboutsummaryrefslogtreecommitdiff
path: root/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils.h')
-rw-r--r--utils.h19
1 files changed, 6 insertions, 13 deletions
diff --git a/utils.h b/utils.h
index e331904..5b3d736 100644
--- a/utils.h
+++ b/utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize
// sampling parameters
- int32_t top_k = 40; // unused
+ int32_t top_k = 40;
float top_p = 0.95f;
float temp = 0.80f;
float repeat_penalty = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
- const gpt_vocab & vocab,
- const float * logits,
- int top_k,
- double top_p,
- double temp,
- std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
const gpt_vocab & vocab,
const float * logits,
std::vector<gpt_vocab::id> & last_n_tokens,
double repeat_penalty,
+ int top_k,
double top_p,
double temp,
std::mt19937 & rng);
+// filer to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+
//
// Quantization
//