diff options
author | Howard Su <howard0su@gmail.com> | 2023-07-05 18:31:23 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-05 18:31:23 +0800 |
commit | 051c70dcd55709c9cbbfa849af035951fe720433 (patch) | |
tree | 2d339ce8cdec9d9a875b5a06dcb7857705dd4be6 | |
parent | 9e4475f5cf639315f61ed7b8da6258bb0c7c5ca9 (diff) |
llama: Don't double count the sampling time (#2107)
-rw-r--r-- | llama.cpp | 20 |
1 files changed, 9 insertions, 11 deletions
@@ -1905,10 +1905,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can return; } - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(ctx, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + // Compute the cumulative probabilities float cum_sum = 0.0f; size_t last_idx = candidates->size; @@ -1937,9 +1937,8 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * return; } - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); // Compute the first and second derivatives std::vector<float> first_derivatives(candidates->size - 1); @@ -1991,11 +1990,11 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c return; } - const int64_t t_start_sample_us = ggml_time_us(); - // Compute the softmax of logits and calculate entropy llama_sample_softmax(nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { entropy += -candidates->data[i].p * logf(candidates->data[i].p); @@ -2164,13 +2163,11 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; } return X; } llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - assert(ctx); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -2185,13 +2182,14 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok candidates->size = 1; } + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + // Normalize the probabilities of the remaining words llama_sample_softmax(ctx, candidates); // Sample the next word X from the remaining words - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } llama_token X = llama_sample_token(ctx, candidates); t_start_sample_us = ggml_time_us(); |