aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorDidzis Gosko <didzis@users.noreply.github.com>2023-06-24 11:47:58 +0300
committerGitHub <noreply@github.com>2023-06-24 11:47:58 +0300
commit527b6fba1d237befb324fd846bda7418c0fa394d (patch)
tree360b44abac0c9a53739444b8ba9e4ccf903938cd /examples
parentd7b7484f74d486f77feb4c0b7af7e1718ed91651 (diff)
llama : make model stateless and context stateful (llama_state) (#1797)
* llama : make model stateless and context stateful * llama : minor cleanup * llama : update internal API declaration * Apply suggestions from code review fix style Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Missing model memory release * Fix style * Add deprecated warning for public API function llama_init_from_file * Update public API use cases: move away from deprecated llama_init_from_file * Deprecate public API function llama_apply_lora_from_file --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/common.cpp22
-rw-r--r--examples/common.h3
-rw-r--r--examples/embedding/embedding.cpp6
-rw-r--r--examples/main/main.cpp8
-rw-r--r--examples/perplexity/perplexity.cpp6
-rw-r--r--examples/quantize-stats/quantize-stats.cpp15
-rw-r--r--examples/save-load-state/save-load-state.cpp29
-rw-r--r--examples/server/server.cpp9
-rw-r--r--examples/simple/simple.cpp8
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp5
10 files changed, 85 insertions, 26 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index fed24e0..6ac4845 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
- llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+ return std::make_tuple(nullptr, nullptr);
+ }
+ llama_context * lctx = llama_new_context_with_model(model, lparams);
if (lctx == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
- return NULL;
+ fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
}
if (!params.lora_adapter.empty()) {
- int err = llama_apply_lora_from_file(lctx,
+ int err = llama_model_apply_lora_from_file(model,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- return NULL;
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
}
}
- return lctx;
+ return std::make_tuple(model, lctx);
}
void console_init(console_state & con_st) {
diff --git a/examples/common.h b/examples/common.h
index 6c2953c..7133201 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -9,6 +9,7 @@
#include <random>
#include <thread>
#include <unordered_map>
+#include <tuple>
#if !defined (_WIN32)
#include <stdio.h>
@@ -95,7 +96,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
// Model utils
//
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
//
// Console utils
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 860f99f..369eac1 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -37,11 +37,12 @@ int main(int argc, char ** argv) {
llama_init_backend();
+ llama_model * model;
llama_context * ctx;
// load the model
- ctx = llama_init_from_gpt_params(params);
- if (ctx == NULL) {
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
@@ -90,6 +91,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
+ llama_free_model(model);
return 0;
}
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 941312f..c1e6bf1 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -107,12 +107,13 @@ int main(int argc, char ** argv) {
llama_init_backend();
+ llama_model * model;
llama_context * ctx;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
- ctx = llama_init_from_gpt_params(params);
- if (ctx == NULL) {
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
@@ -139,6 +140,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
+ llama_free_model(model);
return 0;
}
@@ -147,6 +149,7 @@ int main(int argc, char ** argv) {
if (params.export_cgraph) {
llama_eval_export(ctx, "llama.ggml");
llama_free(ctx);
+ llama_free_model(model);
return 0;
}
@@ -666,6 +669,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
+ llama_free_model(model);
return 0;
}
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index ae8cfe0..b59f597 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -149,11 +149,12 @@ int main(int argc, char ** argv) {
llama_init_backend();
+ llama_model * model;
llama_context * ctx;
// load the model and apply lora adapter, if any
- ctx = llama_init_from_gpt_params(params);
- if (ctx == NULL) {
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}
@@ -169,6 +170,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
+ llama_free_model(model);
return 0;
}
diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp
index 6b8018e..9cea472 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/examples/quantize-stats/quantize-stats.cpp
@@ -320,6 +320,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Loading model\n");
const int64_t t_main_start_us = ggml_time_us();
+ llama_model * model;
llama_context * ctx;
{
@@ -330,12 +331,20 @@ int main(int argc, char ** argv) {
lparams.f16_kv = false;
lparams.use_mlock = false;
- ctx = llama_init_from_file(params.model.c_str(), lparams);
+ model = llama_load_model_from_file(params.model.c_str(), lparams);
- if (ctx == NULL) {
+ if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return 1;
}
+
+ ctx = llama_new_context_with_model(model, lparams);
+
+ if (ctx == NULL) {
+ fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+ llama_free_model(model);
+ return 1;
+ }
}
const auto &tensors = llama_internal_get_tensor_map(ctx);
@@ -357,6 +366,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: error: Quantization should be tested with a float model, "
"this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type);
llama_free(ctx);
+ llama_free_model(model);
return 1;
}
included_layers++;
@@ -415,6 +425,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
+ llama_free_model(model);
// report timing
{
const int64_t t_main_end_us = ggml_time_us();
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index da4d37a..4c86885 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+ auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == nullptr) {
+ return 1;
+ }
+ auto ctx = llama_new_context_with_model(model, lparams);
+ if (ctx == nullptr) {
+ llama_free_model(model);
+ return 1;
+ }
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
@@ -84,6 +94,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -91,23 +103,27 @@ int main(int argc, char ** argv) {
printf("\n\n");
- // free old model
+ // free old context
llama_free(ctx);
- // load new model
- auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+ // make new context
+ auto ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
FILE *fp_read = fopen("dump_state.bin", "rb");
if (state_size != llama_get_state_size(ctx2)) {
fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
const size_t ret = fread(state_mem, 1, state_size, fp_read);
if (ret != state_size) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
@@ -138,6 +154,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx2);
+ llama_free_model(model);
return 1;
}
n_past += 1;
@@ -145,5 +163,8 @@ int main(int argc, char ** argv) {
printf("\n\n");
+ llama_free(ctx2);
+ llama_free_model(model);
+
return 0;
}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c0984aa..de22d30 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -115,6 +115,7 @@ struct llama_server_context {
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
+ llama_model * model = nullptr;
llama_context * ctx = nullptr;
gpt_params params;
@@ -130,6 +131,10 @@ struct llama_server_context {
llama_free(ctx);
ctx = nullptr;
}
+ if (model) {
+ llama_free_model(model);
+ model = nullptr;
+ }
}
void rewind() {
@@ -150,8 +155,8 @@ struct llama_server_context {
bool loadModel(const gpt_params & params_) {
params = params_;
- ctx = llama_init_from_gpt_params(params);
- if (ctx == nullptr) {
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (model == nullptr) {
LOG_ERROR("unable to load model", { { "model", params_.model } });
return false;
}
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 76f991c..fc45c93 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -68,11 +68,12 @@ int main(int argc, char ** argv)
llama_init_backend();
- llama_context * ctx ;
+ llama_model * model;
+ llama_context * ctx;
- ctx = llama_init_from_gpt_params( params );
+ std::tie(model, ctx) = llama_init_from_gpt_params( params );
- if ( ctx == NULL )
+ if ( model == NULL )
{
fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
return 1;
@@ -170,6 +171,7 @@ int main(int argc, char ** argv)
} // wend of main loop
llama_free( ctx );
+ llama_free_model( model );
return 0;
}
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 7ec8595..61c829e 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -3054,7 +3054,8 @@ int main(int argc, char ** argv) {
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
- struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params);
+ struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
+ struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct llama_vocab vocab;
{
@@ -3395,6 +3396,8 @@ int main(int argc, char ** argv) {
delete[] compute_addr;
delete[] compute_buf_0;
delete[] compute_buf_1;
+ llama_free(lctx);
+ llama_free_model(lmodel);
ggml_free(model.ctx);
return 0;