aboutsummaryrefslogtreecommitdiff
path: root/examples/save-load-state
diff options
context:
space:
mode:
authorDidzis Gosko <didzis@users.noreply.github.com>2023-06-24 11:47:58 +0300
committerGitHub <noreply@github.com>2023-06-24 11:47:58 +0300
commit527b6fba1d237befb324fd846bda7418c0fa394d (patch)
tree360b44abac0c9a53739444b8ba9e4ccf903938cd /examples/save-load-state
parentd7b7484f74d486f77feb4c0b7af7e1718ed91651 (diff)
llama : make model stateless and context stateful (llama_state) (#1797)
* llama : make model stateless and context stateful * llama : minor cleanup * llama : update internal API declaration * Apply suggestions from code review fix style Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Missing model memory release * Fix style * Add deprecated warning for public API function llama_init_from_file * Update public API use cases: move away from deprecated llama_init_from_file * Deprecate public API function llama_apply_lora_from_file --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/save-load-state')
-rw-r--r--examples/save-load-state/save-load-state.cpp29
1 files changed, 25 insertions, 4 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index da4d37a..4c86885 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+ auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == nullptr) {
+ return 1;
+ }
+ auto ctx = llama_new_context_with_model(model, lparams);
+ if (ctx == nullptr) {
+ llama_free_model(model);
+ return 1;
+ }
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
@@ -84,6 +94,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -91,23 +103,27 @@ int main(int argc, char ** argv) {
printf("\n\n");
- // free old model
+ // free old context
llama_free(ctx);
- // load new model
- auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+ // make new context
+ auto ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
FILE *fp_read = fopen("dump_state.bin", "rb");
if (state_size != llama_get_state_size(ctx2)) {
fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
const size_t ret = fread(state_mem, 1, state_size, fp_read);
if (ret != state_size) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
@@ -138,6 +154,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -145,5 +163,8 @@ int main(int argc, char ** argv) {
printf("\n\n");
+ llama_free(ctx2);
+ llama_free_model(model);
+
return 0;
}