diff options
author | Didzis Gosko <didzis@users.noreply.github.com> | 2023-06-24 11:47:58 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-24 11:47:58 +0300 |
commit | 527b6fba1d237befb324fd846bda7418c0fa394d (patch) | |
tree | 360b44abac0c9a53739444b8ba9e4ccf903938cd /examples | |
parent | d7b7484f74d486f77feb4c0b7af7e1718ed91651 (diff) |
llama : make model stateless and context stateful (llama_state) (#1797)
* llama : make model stateless and context stateful
* llama : minor cleanup
* llama : update internal API declaration
* Apply suggestions from code review
fix style
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Missing model memory release
* Fix style
* Add deprecated warning for public API function llama_init_from_file
* Update public API use cases: move away from deprecated llama_init_from_file
* Deprecate public API function llama_apply_lora_from_file
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/common.cpp | 22 | ||||
-rw-r--r-- | examples/common.h | 3 | ||||
-rw-r--r-- | examples/embedding/embedding.cpp | 6 | ||||
-rw-r--r-- | examples/main/main.cpp | 8 | ||||
-rw-r--r-- | examples/perplexity/perplexity.cpp | 6 | ||||
-rw-r--r-- | examples/quantize-stats/quantize-stats.cpp | 15 | ||||
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 29 | ||||
-rw-r--r-- | examples/server/server.cpp | 9 | ||||
-rw-r--r-- | examples/simple/simple.cpp | 8 | ||||
-rw-r--r-- | examples/train-text-from-scratch/train-text-from-scratch.cpp | 5 |
10 files changed, 85 insertions, 26 deletions
diff --git a/examples/common.cpp b/examples/common.cpp index fed24e0..6ac4845 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s return res; } -struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { +std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; @@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { lparams.logits_all = params.perplexity; lparams.embedding = params.embedding; - llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); + } + llama_context * lctx = llama_new_context_with_model(model, lparams); if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return NULL; + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); } if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(lctx, + int err = llama_model_apply_lora_from_file(model, params.lora_adapter.c_str(), params.lora_base.empty() ? NULL : params.lora_base.c_str(), params.n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - return NULL; + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); } } - return lctx; + return std::make_tuple(model, lctx); } void console_init(console_state & con_st) { diff --git a/examples/common.h b/examples/common.h index 6c2953c..7133201 100644 --- a/examples/common.h +++ b/examples/common.h @@ -9,6 +9,7 @@ #include <random> #include <thread> #include <unordered_map> +#include <tuple> #if !defined (_WIN32) #include <stdio.h> @@ -95,7 +96,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s // Model utils // -struct llama_context * llama_init_from_gpt_params(const gpt_params & params); +std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params); // // Console utils diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 860f99f..369eac1 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -37,11 +37,12 @@ int main(int argc, char ** argv) { llama_init_backend(); + llama_model * model; llama_context * ctx; // load the model - ctx = llama_init_from_gpt_params(params); - if (ctx == NULL) { + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; } @@ -90,6 +91,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + llama_free_model(model); return 0; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 941312f..c1e6bf1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -107,12 +107,13 @@ int main(int argc, char ** argv) { llama_init_backend(); + llama_model * model; llama_context * ctx; g_ctx = &ctx; // load the model and apply lora adapter, if any - ctx = llama_init_from_gpt_params(params); - if (ctx == NULL) { + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; } @@ -139,6 +140,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + llama_free_model(model); return 0; } @@ -147,6 +149,7 @@ int main(int argc, char ** argv) { if (params.export_cgraph) { llama_eval_export(ctx, "llama.ggml"); llama_free(ctx); + llama_free_model(model); return 0; } @@ -666,6 +669,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + llama_free_model(model); return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index ae8cfe0..b59f597 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -149,11 +149,12 @@ int main(int argc, char ** argv) { llama_init_backend(); + llama_model * model; llama_context * ctx; // load the model and apply lora adapter, if any - ctx = llama_init_from_gpt_params(params); - if (ctx == NULL) { + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; } @@ -169,6 +170,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + llama_free_model(model); return 0; } diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 6b8018e..9cea472 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -320,6 +320,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "Loading model\n"); const int64_t t_main_start_us = ggml_time_us(); + llama_model * model; llama_context * ctx; { @@ -330,12 +331,20 @@ int main(int argc, char ** argv) { lparams.f16_kv = false; lparams.use_mlock = false; - ctx = llama_init_from_file(params.model.c_str(), lparams); + model = llama_load_model_from_file(params.model.c_str(), lparams); - if (ctx == NULL) { + if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return 1; } + + ctx = llama_new_context_with_model(model, lparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return 1; + } } const auto &tensors = llama_internal_get_tensor_map(ctx); @@ -357,6 +366,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: Quantization should be tested with a float model, " "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); llama_free(ctx); + llama_free_model(model); return 1; } included_layers++; @@ -415,6 +425,7 @@ int main(int argc, char ** argv) { llama_free(ctx); + llama_free_model(model); // report timing { const int64_t t_main_end_us = ggml_time_us(); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index da4d37a..4c86885 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -35,12 +35,22 @@ int main(int argc, char ** argv) { auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); // init - auto ctx = llama_init_from_file(params.model.c_str(), lparams); + auto model = llama_load_model_from_file(params.model.c_str(), lparams); + if (model == nullptr) { + return 1; + } + auto ctx = llama_new_context_with_model(model, lparams); + if (ctx == nullptr) { + llama_free_model(model); + return 1; + } auto tokens = std::vector<llama_token>(params.n_ctx); auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); if (n_prompt_tokens < 1) { fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); + llama_free(ctx); + llama_free_model(model); return 1; } @@ -84,6 +94,8 @@ int main(int argc, char ** argv) { printf("%s", next_token_str); if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx); + llama_free_model(model); return 1; } n_past += 1; @@ -91,23 +103,27 @@ int main(int argc, char ** argv) { printf("\n\n"); - // free old model + // free old context llama_free(ctx); - // load new model - auto ctx2 = llama_init_from_file(params.model.c_str(), lparams); + // make new context + auto ctx2 = llama_new_context_with_model(model, lparams); // Load state (rng, logits, embedding and kv_cache) from file { FILE *fp_read = fopen("dump_state.bin", "rb"); if (state_size != llama_get_state_size(ctx2)) { fprintf(stderr, "\n%s : failed to validate state size\n", __func__); + llama_free(ctx2); + llama_free_model(model); return 1; } const size_t ret = fread(state_mem, 1, state_size, fp_read); if (ret != state_size) { fprintf(stderr, "\n%s : failed to read state\n", __func__); + llama_free(ctx2); + llama_free_model(model); return 1; } @@ -138,6 +154,8 @@ int main(int argc, char ** argv) { printf("%s", next_token_str); if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx2); + llama_free_model(model); return 1; } n_past += 1; @@ -145,5 +163,8 @@ int main(int argc, char ** argv) { printf("\n\n"); + llama_free(ctx2); + llama_free_model(model); + return 0; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c0984aa..de22d30 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -115,6 +115,7 @@ struct llama_server_context { std::vector<llama_token> embd; std::vector<llama_token> last_n_tokens; + llama_model * model = nullptr; llama_context * ctx = nullptr; gpt_params params; @@ -130,6 +131,10 @@ struct llama_server_context { llama_free(ctx); ctx = nullptr; } + if (model) { + llama_free_model(model); + model = nullptr; + } } void rewind() { @@ -150,8 +155,8 @@ struct llama_server_context { bool loadModel(const gpt_params & params_) { params = params_; - ctx = llama_init_from_gpt_params(params); - if (ctx == nullptr) { + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == nullptr) { LOG_ERROR("unable to load model", { { "model", params_.model } }); return false; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 76f991c..fc45c93 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -68,11 +68,12 @@ int main(int argc, char ** argv) llama_init_backend(); - llama_context * ctx ; + llama_model * model; + llama_context * ctx; - ctx = llama_init_from_gpt_params( params ); + std::tie(model, ctx) = llama_init_from_gpt_params( params ); - if ( ctx == NULL ) + if ( model == NULL ) { fprintf( stderr , "%s: error: unable to load model\n" , __func__ ); return 1; @@ -170,6 +171,7 @@ int main(int argc, char ** argv) } // wend of main loop llama_free( ctx ); + llama_free_model( model ); return 0; } diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 7ec8595..61c829e 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -3054,7 +3054,8 @@ int main(int argc, char ** argv) { struct llama_context_params llama_params = llama_context_default_params(); llama_params.vocab_only = true; - struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params); + struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params); + struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); struct llama_vocab vocab; { @@ -3395,6 +3396,8 @@ int main(int argc, char ** argv) { delete[] compute_addr; delete[] compute_buf_0; delete[] compute_buf_1; + llama_free(lctx); + llama_free_model(lmodel); ggml_free(model.ctx); return 0; |