aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp14
1 files changed, 2 insertions, 12 deletions
diff --git a/llama.cpp b/llama.cpp
index 23e746d..3b0024e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_context * guidance_ctx,
- float scale,
- float smooth_factor) {
+ float scale) {
int64_t t_start_sample_us = ggml_time_us();
assert(ctx);
@@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance(
for (int i = 0; i < n_vocab; ++i) {
float logit_guidance = logits_guidance[i];
float logit_base = logits_base[i];
- logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
- }
-
- llama_log_softmax(logits_guidance, n_vocab);
-
- for (int i = 0; i < n_vocab; ++i) {
- float logit_base = logits_base[i];
- float logit_guidance = logits_guidance[i];
-
- candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+ candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
}
if (ctx) {