aboutsummaryrefslogtreecommitdiff
path: root/examples/main
diff options
context:
space:
mode:
authorAlex Klinkhamer <from.github.com.917@grencez.dev>2023-04-21 11:18:09 -0700
committerGitHub <noreply@github.com>2023-04-21 21:18:09 +0300
commit9411288271ab548216902a029f42a0a38ebcedb7 (patch)
treef480e35b257b61b0325be1e15ee9ff8be0f3c9a3 /examples/main
parent8687c1f2581d059cd5b6a9502f89bd343566062a (diff)
main : evaluate tokens in batches after swapping context (#1014)
* examples : evaluate tokens in batches after swapping context * Update examples/main/main.cpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/main')
-rw-r--r--examples/main/main.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index b7b3c41..65db792 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
// infinite text generation via context swapping
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
- // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+ // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() > n_ctx) {
const int n_left = n_past - params.n_keep;
@@ -282,13 +282,21 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
}
- if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
- fprintf(stderr, "%s : failed to eval\n", __func__);
- return 1;
+ // evaluate tokens in batches
+ // embd is typically prepared beforehand to fit within a batch, but not always
+ for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
+ int n_eval = (int) embd.size() - i;
+ if (n_eval > params.n_batch) {
+ n_eval = params.n_batch;
+ }
+ if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
+ }
+ n_past += n_eval;
}
}
- n_past += embd.size();
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {