diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 483 |
1 files changed, 376 insertions, 107 deletions
@@ -28,6 +28,7 @@ #include <atomic> #include <mutex> #include <sstream> +#include <numeric> #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1475,109 +1476,402 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co // sampling // -static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) { - // find the top k tokens - std::partial_sort( - logits_id.begin(), - logits_id.begin() + top_k, logits_id.end(), - [](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) { - return a.first > b.first; - }); +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { + assert(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); - logits_id.resize(top_k); + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = candidates->data[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } -static llama_vocab::id llama_sample_top_p_top_k( - llama_context & lctx, - const std::vector<llama_vocab::id> & last_n_tokens, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - auto & rng = lctx.rng; - - const int n_logits = lctx.model.hparams.n_vocab; - - const auto & logits = lctx.logits; - const auto * plogits = logits.data() + logits.size() - n_logits; - - if (temp <= 0) { - // select the token with the highest logit directly - float max_logit = plogits[0]; - llama_vocab::id max_id = 0; - - for (int i = 1; i < n_logits; ++i) { - if (plogits[i] > max_logit) { - max_logit = plogits[i]; - max_id = i; - } +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { + const int64_t t_start_sample_us = ggml_time_us(); + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); + + // Sort scores in descending order + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k == (int) candidates->size) { + std::sort(candidates->data, candidates->data + candidates->size, comp); + } else { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); } - return max_id; + candidates->sorted = true; } + candidates->size = k; - std::vector<std::pair<float, llama_vocab::id>> logits_id; - logits_id.reserve(n_logits); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} - { - const float scale = 1.0f/temp; - for (int i = 0; i < n_logits; ++i) { - // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) - // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { - // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if (plogits[i] < 0.0f) { - logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); - } - } else { - logits_id.push_back(std::make_pair(plogits[i]*scale, i)); - } +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; + + // Check if the running sum is greater than p or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep) { + last_idx = i; + break; } } - sample_top_k(logits_id, top_k > 0 ? std::min(top_k, n_logits) : n_logits); + // Resize the output vector to keep only the top-p tokens + candidates->size = last_idx; - // compute probs for the top k tokens - std::vector<float> probs; - probs.reserve(logits_id.size()); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} - float maxl = logits_id[0].first; - double sum = 0.0; - for (const auto & kv : logits_id) { - const float p = expf(kv.first - maxl); - probs.push_back(p); - sum += p; +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { + return; } - // normalize the probs - for (auto & p : probs) { - p /= sum; + const int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates); + + // Compute the first and second derivatives + std::vector<float> first_derivatives(candidates->size - 1); + std::vector<float> second_derivatives(candidates->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; } - if (top_p < 1.0) { - double cumsum = 0.0; - for (int i = 0; i < (int) probs.size(); i++) { - cumsum += probs[i]; - if (cumsum >= top_p) { - probs.resize(i + 1); - logits_id.resize(i + 1); - break; - } + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = abs(second_derivatives[i]); + } + + // Normalize the second derivatives + float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; } } - //printf("\n"); - //for (int i = 0; i < (int) 10; i++) { - // printf("%d: '%s' %f\n", i, lctx.vocab.id_to_token.at(logits_id[i].second).tok.c_str(), probs[i]); - //} - //printf("\n\n"); - //exit(0); + // Resize the output vector to keep only the tokens above the tail location + candidates->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(nullptr, candidates); + + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector<float> shifted_scores; + for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector<size_t> indices(candidates->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector<llama_token_data> new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates->data[idx]); + } + + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= temp; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty) { + if (last_tokens_size == 0 || penalty == 1.0f) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates->size; ++i) { + auto token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); + if (token_iter == last_tokens + last_tokens_size) { + continue; + } + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty; + } else { + candidates->data[i].logit /= penalty; + } + } + + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map<llama_token, int> token_count; + for (size_t i = 0; i < last_tokens_size; ++i) { + token_count[last_tokens_p[i]]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates->size; ++i) { + auto token_iter = token_count.find(candidates->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + int count = token_iter->second; + candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; + } + + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { + assert(ctx); + auto N = float(llama_n_vocab(ctx)); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_sample_top_k(nullptr, candidates, int(k)); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } + return X; +} + +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { + assert(ctx); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + // Normalize the probabilities of the remaining words + llama_sample_softmax(ctx, candidates); + + // Sample the next word X from the remaining words + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { + const int64_t t_start_sample_us = ggml_time_us(); + + // Find max element + auto max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } + return result; +} + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates); + + std::vector<float> probs; + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } std::discrete_distribution<> dist(probs.begin(), probs.end()); + auto & rng = ctx->rng; int idx = dist(rng); - return logits_id[idx].second; + llama_token result = candidates->data[idx].id; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + return result; } // @@ -2348,33 +2642,8 @@ llama_token llama_token_eos() { return 2; } -llama_token llama_sample_top_p_top_k( - llama_context * ctx, - const llama_token * last_n_tokens_data, - int last_n_tokens_size, - int top_k, - float top_p, - float temp, - float repeat_penalty) { - const int64_t t_start_sample_us = ggml_time_us(); - - llama_token result = 0; - - // TODO: avoid this ... - const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size); - - result = llama_sample_top_p_top_k( - *ctx, - last_n_tokens, - top_k, - top_p, - temp, - repeat_penalty); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - - return result; +llama_token llama_token_nl() { + return 13; } |