From 70269cae37538461ff816e714afbb3ebcdcdc26b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 1 May 2023 14:54:59 +0300 Subject: llama : fix session load / save (#1263) --- examples/main/main.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'examples/main') diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 990d0fa..78fc9a1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -161,23 +161,22 @@ int main(int argc, char ** argv) { std::vector session_tokens; if (!path_session.empty()) { - fprintf(stderr, "%s: attempting to load saved session from %s..\n", __func__, path_session.c_str()); + fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); - // REVIEW - fopen to check for existing session + // fopen to check for existing session FILE * fp = std::fopen(path_session.c_str(), "rb"); if (fp != NULL) { std::fclose(fp); session_tokens.resize(params.n_ctx); size_t n_token_count_out = 0; - const size_t n_session_bytes = llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out); + if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); + return 1; + } session_tokens.resize(n_token_count_out); - if (n_session_bytes > 0) { - fprintf(stderr, "%s: loaded %zu bytes of session data!\n", __func__, n_session_bytes); - } else { - fprintf(stderr, "%s: could not load session file, will recreate\n", __func__); - } + fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); } else { fprintf(stderr, "%s: session file does not exist, will create\n", __func__); } @@ -214,7 +213,7 @@ int main(int argc, char ** argv) { } // number of tokens to keep when resetting context - if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) { + if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) { params.n_keep = (int)embd_inp.size(); } @@ -329,7 +328,7 @@ int main(int argc, char ** argv) { // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); - // REVIEW - stop saving session if we run out of context + // stop saving session if we run out of context path_session = ""; //printf("\n---\n"); @@ -355,6 +354,7 @@ int main(int argc, char ** argv) { n_session_consumed++; if (n_session_consumed >= (int) session_tokens.size()) { + ++i; break; } } -- cgit v1.2.3