about summary refs log tree commit diff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index 08ec21a..2d09d6c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2167,6 +2167,62 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
// Convert `array` (raw logits) in place into log-probabilities:
// array[i] <- log(softmax(array)[i]). The maximum logit is subtracted
// before exponentiating for numerical stability (avoids overflow in expf).
static void llama_log_softmax(float * array, size_t size) {
    const float max_l = *std::max_element(array, array + size);

    // First pass: exponentiate the shifted logits and accumulate the normalizer.
    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        const float e = expf(array[i] - max_l);
        array[i] = e;
        sum += e;
    }

    // Second pass: normalize and move back to log-space.
    for (size_t i = 0; i < size; ++i) {
        array[i] = logf(array[i] / sum);
    }
}
+
+void llama_sample_classifier_free_guidance(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates,
+ struct llama_context * guidance_ctx,
+ float scale,
+ float smooth_factor) {
+ int64_t t_start_sample_us = t_start_sample_us = ggml_time_us();
+
+ assert(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
+ assert(n_vocab == (int)candidates->size);
+ assert(!candidates->sorted);
+
+ std::vector<float> logits_base;
+ logits_base.reserve(candidates->size);
+ for (size_t i = 0; i < candidates->size; ++i) {
+ logits_base.push_back(candidates->data[i].logit);
+ }
+ llama_log_softmax(logits_base.data(), candidates->size);
+
+ float* logits_guidance = llama_get_logits(guidance_ctx);
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_guidance = logits_guidance[i];
+ float logit_base = logits_base[i];
+ logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
+ }
+
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_base = logits_base[i];
+ float logit_guidance = logits_guidance[i];
+
+ candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+ }
+
+ if (ctx) {
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
assert(ctx);