aboutsummaryrefslogtreecommitdiff
path: root/examples/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/common.cpp')
-rw-r--r--examples/common.cpp22
1 files changed, 15 insertions, 7 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index fed24e0..6ac4845 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
- llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+ return std::make_tuple(nullptr, nullptr);
+ }
+ llama_context * lctx = llama_new_context_with_model(model, lparams);
if (lctx == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
- return NULL;
+ fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
}
if (!params.lora_adapter.empty()) {
- int err = llama_apply_lora_from_file(lctx,
+ int err = llama_model_apply_lora_from_file(model,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- return NULL;
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
}
}
- return lctx;
+ return std::make_tuple(model, lctx);
}
void console_init(console_state & con_st) {