aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRon Evans <ron@hybridgroup.com>2023-05-02 22:39:51 +0200
committerGitHub <noreply@github.com>2023-05-02 23:39:51 +0300
commit67c77799e025a8425c23a6a0599c007f46ded590 (patch)
tree4619ab8a7e1ac62079f1f5f912c0022d2c019d13 /examples
parent0e6cbff1b7509628c588e661166f6e187137734d (diff)
examples : add llama_init_from_gpt_params() common function (#1290)
Signed-off-by: deadprogram <ron@hybridgroup.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/common.cpp31
-rw-r--r--examples/common.h6
-rw-r--r--examples/embedding/embedding.cpp22
-rw-r--r--examples/main/main.cpp33
-rw-r--r--examples/perplexity/perplexity.cpp35
5 files changed, 51 insertions, 76 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 2bf0dc5..9b23b1f 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+ auto lparams = llama_context_default_params();
+
+ lparams.n_ctx = params.n_ctx;
+ lparams.n_parts = params.n_parts;
+ lparams.seed = params.seed;
+ lparams.f16_kv = params.memory_f16;
+ lparams.use_mmap = params.use_mmap;
+ lparams.use_mlock = params.use_mlock;
+
+ llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+
+ if (lctx == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+ return NULL;
+ }
+
+ if (!params.lora_adapter.empty()) {
+ int err = llama_apply_lora_from_file(lctx,
+ params.lora_adapter.c_str(),
+ params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ return NULL;
+ }
+ }
+
+ return lctx;
+}
+
/* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
diff --git a/examples/common.h b/examples/common.h
index 627696e..138d0de 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -78,6 +78,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
//
+// Model utils
+//
+
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+
+//
// Console utils
//
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 1e9d8a8..e4b7291 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -35,24 +35,10 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model
- {
- auto lparams = llama_context_default_params();
-
- lparams.n_ctx = params.n_ctx;
- lparams.n_parts = params.n_parts;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.logits_all = params.perplexity;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
- lparams.embedding = params.embedding;
-
- ctx = llama_init_from_file(params.model.c_str(), lparams);
-
- if (ctx == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
- return 1;
- }
+ ctx = llama_init_from_gpt_params(params);
+ if (ctx == NULL) {
+ fprintf(stderr, "%s: error: unable to load model\n", __func__);
+ return 1;
}
// print system information
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 54836b3..a10256a 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -101,34 +101,11 @@ int main(int argc, char ** argv) {
llama_context * ctx;
g_ctx = &ctx;
- // load the model
- {
- auto lparams = llama_context_default_params();
-
- lparams.n_ctx = params.n_ctx;
- lparams.n_parts = params.n_parts;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
-
- ctx = llama_init_from_file(params.model.c_str(), lparams);
-
- if (ctx == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
- return 1;
- }
- }
-
- if (!params.lora_adapter.empty()) {
- int err = llama_apply_lora_from_file(ctx,
- params.lora_adapter.c_str(),
- params.lora_base.empty() ? NULL : params.lora_base.c_str(),
- params.n_threads);
- if (err != 0) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- return 1;
- }
+ // load the model and apply lora adapter, if any
+ ctx = llama_init_from_gpt_params(params);
+ if (ctx == NULL) {
+ fprintf(stderr, "%s: error: unable to load model\n", __func__);
+ return 1;
}
// print system information
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index d474bc5..299a199 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -122,36 +122,11 @@ int main(int argc, char ** argv) {
llama_context * ctx;
- // load the model
- {
- auto lparams = llama_context_default_params();
-
- lparams.n_ctx = params.n_ctx;
- lparams.n_parts = params.n_parts;
- lparams.seed = params.seed;
- lparams.f16_kv = params.memory_f16;
- lparams.logits_all = params.perplexity;
- lparams.use_mmap = params.use_mmap;
- lparams.use_mlock = params.use_mlock;
- lparams.embedding = params.embedding;
-
- ctx = llama_init_from_file(params.model.c_str(), lparams);
-
- if (ctx == NULL) {
- fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
- return 1;
- }
- }
-
- if (!params.lora_adapter.empty()) {
- int err = llama_apply_lora_from_file(ctx,
- params.lora_adapter.c_str(),
- params.lora_base.empty() ? NULL : params.lora_base.c_str(),
- params.n_threads);
- if (err != 0) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- return 1;
- }
+ // load the model and apply lora adapter, if any
+ ctx = llama_init_from_gpt_params(params);
+ if (ctx == NULL) {
+ fprintf(stderr, "%s: error: unable to load model\n", __func__);
+ return 1;
}
// print system information