aboutsummaryrefslogtreecommitdiff
path: root/examples/save-load-state
diff options
context:
space:
mode:
Diffstat (limited to 'examples/save-load-state')
-rw-r--r--examples/save-load-state/save-load-state.cpp29
1 files changed, 25 insertions, 4 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index da4d37a..4c86885 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+ auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == nullptr) {
+ return 1;
+ }
+ auto ctx = llama_new_context_with_model(model, lparams);
+ if (ctx == nullptr) {
+ llama_free_model(model);
+ return 1;
+ }
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
@@ -84,6 +94,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -91,23 +103,27 @@ int main(int argc, char ** argv) {
printf("\n\n");
- // free old model
+ // free old context
llama_free(ctx);
- // load new model
- auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+ // make new context
+ auto ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
FILE *fp_read = fopen("dump_state.bin", "rb");
if (state_size != llama_get_state_size(ctx2)) {
fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
const size_t ret = fread(state_mem, 1, state_size, fp_read);
if (ret != state_size) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
@@ -138,6 +154,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -145,5 +163,8 @@ int main(int argc, char ** argv) {
printf("\n\n");
+ llama_free(ctx2);
+ llama_free_model(model);
+
return 0;
}