aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHoward Su <howard0su@gmail.com>2023-07-05 18:31:23 +0800
committerGitHub <noreply@github.com>2023-07-05 18:31:23 +0800
commit051c70dcd55709c9cbbfa849af035951fe720433 (patch)
tree2d339ce8cdec9d9a875b5a06dcb7857705dd4be6
parent9e4475f5cf639315f61ed7b8da6258bb0c7c5ca9 (diff)
llama: Don't double count the sampling time (#2107)
-rw-r--r--llama.cpp20
1 files changed, 9 insertions, 11 deletions
diff --git a/llama.cpp b/llama.cpp
index 83e93ef..e04fbfc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1905,10 +1905,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
return;
}
- const int64_t t_start_sample_us = ggml_time_us();
-
llama_sample_softmax(ctx, candidates);
+ const int64_t t_start_sample_us = ggml_time_us();
+
// Compute the cumulative probabilities
float cum_sum = 0.0f;
size_t last_idx = candidates->size;
@@ -1937,9 +1937,8 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
return;
}
- const int64_t t_start_sample_us = ggml_time_us();
-
llama_sample_softmax(nullptr, candidates);
+ const int64_t t_start_sample_us = ggml_time_us();
// Compute the first and second derivatives
std::vector<float> first_derivatives(candidates->size - 1);
@@ -1991,11 +1990,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c
return;
}
- const int64_t t_start_sample_us = ggml_time_us();
-
// Compute the softmax of logits and calculate entropy
llama_sample_softmax(nullptr, candidates);
+ const int64_t t_start_sample_us = ggml_time_us();
+
float entropy = 0.0f;
for (size_t i = 0; i < candidates->size; ++i) {
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
@@ -2164,13 +2163,11 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- ctx->n_sample++;
}
return X;
}
llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
- assert(ctx);
int64_t t_start_sample_us;
t_start_sample_us = ggml_time_us();
@@ -2185,13 +2182,14 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok
candidates->size = 1;
}
+ if (ctx) {
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+
// Normalize the probabilities of the remaining words
llama_sample_softmax(ctx, candidates);
// Sample the next word X from the remaining words
- if (ctx) {
- ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
llama_token X = llama_sample_token(ctx, candidates);
t_start_sample_us = ggml_time_us();