aboutsummaryrefslogtreecommitdiff
path: root/examples/main
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-05-01 14:54:59 +0300
committerGitHub <noreply@github.com>2023-05-01 14:54:59 +0300
commit70269cae37538461ff816e714afbb3ebcdcdc26b (patch)
tree448ee5f5f8e93816fc7e60c7b05b29668536884a /examples/main
parentb925f1f1b082319ee69943f8d1a83ac9b6ff09ca (diff)
llama : fix session load / save (#1263)
Diffstat (limited to 'examples/main')
-rw-r--r--examples/main/main.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 990d0fa..78fc9a1 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -161,23 +161,22 @@ int main(int argc, char ** argv) {
std::vector<llama_token> session_tokens;
if (!path_session.empty()) {
- fprintf(stderr, "%s: attempting to load saved session from %s..\n", __func__, path_session.c_str());
+ fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
- // REVIEW - fopen to check for existing session
+ // fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(params.n_ctx);
size_t n_token_count_out = 0;
- const size_t n_session_bytes = llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
+ if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+ fprintf(stderr, "%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
+ return 1;
+ }
session_tokens.resize(n_token_count_out);
- if (n_session_bytes > 0) {
- fprintf(stderr, "%s: loaded %zu bytes of session data!\n", __func__, n_session_bytes);
- } else {
- fprintf(stderr, "%s: could not load session file, will recreate\n", __func__);
- }
+ fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
@@ -214,7 +213,7 @@ int main(int argc, char ** argv) {
}
// number of tokens to keep when resetting context
- if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+ if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
params.n_keep = (int)embd_inp.size();
}
@@ -329,7 +328,7 @@ int main(int argc, char ** argv) {
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
- // REVIEW - stop saving session if we run out of context
+ // stop saving session if we run out of context
path_session = "";
//printf("\n---\n");
@@ -355,6 +354,7 @@ int main(int argc, char ** argv) {
n_session_consumed++;
if (n_session_consumed >= (int) session_tokens.size()) {
+ ++i;
break;
}
}