diff options
| author | Alex Klinkhamer <from.github.com.917@grencez.dev> | 2023-04-21 11:18:09 -0700 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-21 21:18:09 +0300 | 
| commit | 9411288271ab548216902a029f42a0a38ebcedb7 (patch) | |
| tree | f480e35b257b61b0325be1e15ee9ff8be0f3c9a3 /examples/main | |
| parent | 8687c1f2581d059cd5b6a9502f89bd343566062a (diff) | |
main : evaluate tokens in batches after swapping context (#1014)
* examples : evaluate tokens in batches after swapping context
* Update examples/main/main.cpp
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/main')
| -rw-r--r-- | examples/main/main.cpp | 18 | 
1 files changed, 13 insertions, 5 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b7b3c41..65db792 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,7 +264,7 @@ int main(int argc, char ** argv) {              // infinite text generation via context swapping              // if we run out of context:              // - take the n_keep first tokens from the original prompt (via n_past) -            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch +            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches              if (n_past + (int) embd.size() > n_ctx) {                  const int n_left = n_past - params.n_keep; @@ -282,13 +282,21 @@ int main(int argc, char ** argv) {                  //printf("\n---\n");              } -            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { -                fprintf(stderr, "%s : failed to eval\n", __func__); -                return 1; +            // evaluate tokens in batches +            // embd is typically prepared beforehand to fit within a batch, but not always +            for (int i = 0; i < (int) embd.size(); i += params.n_batch) { +                int n_eval = (int) embd.size() - i; +                if (n_eval > params.n_batch) { +                    n_eval = params.n_batch; +                } +                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { +                    fprintf(stderr, "%s : failed to eval\n", __func__); +                    return 1; +                } +                n_past += n_eval;              }          } -        n_past += embd.size();          embd.clear();          if ((int) embd_inp.size() <= n_consumed && !is_interacting) {  | 
