author     Evan Jones <evan.q.jones@gmail.com>    2023-05-10 11:37:14 -0400
committer  GitHub <noreply@github.com>            2023-05-10 11:37:14 -0400
commit     cf348a60e0af3905acd1d297cb064b918265b7ac (patch)
tree       b5480b47918c0d1f386db71a195028fd5ca095be /examples/main
parent     e6a46b0ed1884c77267dc70693183e3b7164e0e0 (diff)
main : add option to save full output to session (#1338)
* main : add option to save full output to session
* split behavior into --session and --prompt-cache
* restore original implementation with new names
* PR comments
* move the check for incompatible parameters to gpt_params_parse
* Fix whitespace

Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>

---------

Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>
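The check for incompatible parameters mentioned above lives in gpt_params_parse, outside examples/main, so it does not appear in the diff below. A minimal sketch of what such a check could look like, assuming the --prompt-cache-all flag maps to params.prompt_cache_all (used in the main.cpp hunk below) and the interactive-mode fields of the era's gpt_params struct; the helper name validate_prompt_cache_params is illustrative only, not code from this patch.

// Sketch only: not part of the patch. It illustrates the kind of flag-compatibility
// check the commit message says was moved into gpt_params_parse.
#include <cstdio>

#include "common.h"   // gpt_params, as used by examples/main/main.cpp

static bool validate_prompt_cache_params(const gpt_params & params) {
    // Saving the full output to the cache only makes sense for a single,
    // non-interactive run: in interactive mode the final state keeps changing.
    if (params.prompt_cache_all && (params.interactive || params.interactive_first)) {
        std::fprintf(stderr, "error: --prompt-cache-all is not supported in interactive mode\n");
        return false;
    }
    return true;
}

Doing this during argument parsing rejects bad flag combinations before a model is loaded and keeps main() free of compatibility logic.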
Diffstat (limited to 'examples/main')
-rw-r--r--  examples/main/README.md    4
-rw-r--r--  examples/main/main.cpp    20
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/examples/main/README.md b/examples/main/README.md
index 35f87bc..7c03f92 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -270,9 +270,9 @@ These options help improve the performance and memory usage of the LLaMA models.
- `-b N, --batch_size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
-### Session Caching
+### Prompt Caching
-- `--session FNAME`: Specify a file to load/save the session, which caches the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The session file is created during the first run and is reused in subsequent runs. If you change your prompt such that 75% or less of the session is reusable, the existing session file will be overwritten with a new, updated version to maintain optimal performance.
+- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
### Quantization
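The --prompt-cache description above covers the first-run/reuse behavior. The load half of that pattern (the save half appears in the main.cpp hunks below) looks roughly like the following; a minimal sketch using the llama.cpp C API of this period (llama_load_session_file, llama_n_ctx). The helper name try_load_prompt_cache and its error handling are illustrative, not code from this patch.

// Sketch only: restore cached model state from path_session, returning the tokens
// that were cached so the caller can count how many match the current prompt.
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

static std::vector<llama_token> try_load_prompt_cache(llama_context * ctx,
                                                      const std::string & path_session) {
    std::vector<llama_token> session_tokens;
    if (path_session.empty()) {
        return session_tokens;
    }

    // The cache file may legitimately not exist yet: it is created on the first run.
    FILE * fp = std::fopen(path_session.c_str(), "rb");
    if (fp == NULL) {
        std::fprintf(stderr, "%s: no prompt cache at '%s', it will be created\n",
                     __func__, path_session.c_str());
        return session_tokens;
    }
    std::fclose(fp);

    session_tokens.resize(llama_n_ctx(ctx));
    size_t n_token_count_out = 0;
    if (!llama_load_session_file(ctx, path_session.c_str(),
                                 session_tokens.data(), session_tokens.size(),
                                 &n_token_count_out)) {
        std::fprintf(stderr, "%s: failed to load prompt cache from '%s'\n",
                     __func__, path_session.c_str());
        session_tokens.clear();
        return session_tokens;
    }
    session_tokens.resize(n_token_count_out);
    return session_tokens;
}

main.cpp then counts how many of the returned tokens match the current prompt (n_matching_session_tokens) and, after this patch, re-saves the cache whenever the match is not exact, rather than only below the old 75% threshold.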
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e1172a..bd1c4ab 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -139,7 +139,7 @@ int main(int argc, char ** argv) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');
- std::string path_session = params.path_session;
+ std::string path_session = params.path_prompt_cache;
std::vector<llama_token> session_tokens;
if (!path_session.empty()) {
@@ -292,14 +292,9 @@ int main(int argc, char ** argv) {
is_interacting = params.interactive_first;
}
- bool is_antiprompt = false;
- bool input_echo = true;
-
- // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
- // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
- // initial prompt so it doesn't need to be an exact match.
- bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
-
+ bool is_antiprompt = false;
+ bool input_echo = true;
+ bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
int n_past = 0;
int n_remain = params.n_predict;
@@ -328,7 +323,7 @@ int main(int argc, char ** argv) {
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
// stop saving session if we run out of context
- path_session = "";
+ path_session.clear();
//printf("\n---\n");
//printf("resetting: '");
@@ -603,6 +598,11 @@ int main(int argc, char ** argv) {
}
}
+ if (!path_session.empty() && params.prompt_cache_all) {
+ fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
+ llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+ }
+
llama_print_timings(ctx);
llama_free(ctx);