aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorSIGSEGV <21287366+akr2002@users.noreply.github.com>2023-07-12 19:18:43 +0530
committerGitHub <noreply@github.com>2023-07-12 19:18:43 +0530
commit2516af4cd61f509c995b4f78fdf123cba33f3509 (patch)
treede7324f01b9454fb30e4d827b8300d02fd982ed3 /llama.cpp
parentff34a7d385fc47c4d432fd8c19306d5aca814d05 (diff)
parent4e7464ef88885cb3532738b03cac890f4077fa20 (diff)
Merge branch 'ggerganov:master' into master
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index 08ec21a..2d09d6c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2167,6 +2167,62 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
// Converts `size` raw logits stored in `array` into log-probabilities, in place
// (log-softmax): array[i] <- array[i] - max - log(sum_j exp(array[j] - max)).
//
// The result is computed by subtraction instead of log(exp(x - max) / sum), so
// entries far below the maximum keep a finite value rather than collapsing to
// -inf when expf() underflows to zero. Entries that are already -inf stay -inf.
//
// A null pointer or an empty range is a no-op (std::max_element on an empty
// range returns the end iterator, which must not be dereferenced).
static void llama_log_softmax(float * array, size_t size) {
    if (array == nullptr || size == 0) {
        return;
    }

    const float max_l = *std::max_element(array, array + size);

    // Shifting by the maximum keeps every exponent <= 0, so expf() cannot overflow.
    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        sum += expf(array[i] - max_l);
    }

    // sum >= 1 for finite input because the maximum element contributes expf(0).
    const float log_sum = logf(sum);
    for (size_t i = 0; i < size; ++i) {
        array[i] = array[i] - max_l - log_sum;
    }
}
+
+void llama_sample_classifier_free_guidance(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates,
+ struct llama_context * guidance_ctx,
+ float scale,
+ float smooth_factor) {
+ int64_t t_start_sample_us = t_start_sample_us = ggml_time_us();
+
+ assert(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
+ assert(n_vocab == (int)candidates->size);
+ assert(!candidates->sorted);
+
+ std::vector<float> logits_base;
+ logits_base.reserve(candidates->size);
+ for (size_t i = 0; i < candidates->size; ++i) {
+ logits_base.push_back(candidates->data[i].logit);
+ }
+ llama_log_softmax(logits_base.data(), candidates->size);
+
+ float* logits_guidance = llama_get_logits(guidance_ctx);
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_guidance = logits_guidance[i];
+ float logit_base = logits_base[i];
+ logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
+ }
+
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_base = logits_base[i];
+ float logit_guidance = logits_guidance[i];
+
+ candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+ }
+
+ if (ctx) {
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
assert(ctx);