| author | Ron Evans <ron@hybridgroup.com> | 2023-05-02 22:39:51 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-02 23:39:51 +0300 |
| commit | 67c77799e025a8425c23a6a0599c007f46ded590 (patch) | |
| tree | 4619ab8a7e1ac62079f1f5f912c0022d2c019d13 /examples | |
| parent | 0e6cbff1b7509628c588e661166f6e187137734d (diff) | |
examples : add llama_init_from_gpt_params() common function (#1290)
Signed-off-by: deadprogram <ron@hybridgroup.com>
Diffstat (limited to 'examples')
| -rw-r--r-- | examples/common.cpp | 31 |
|---|---|---|
| -rw-r--r-- | examples/common.h | 6 |
| -rw-r--r-- | examples/embedding/embedding.cpp | 22 |
| -rw-r--r-- | examples/main/main.cpp | 33 |
| -rw-r--r-- | examples/perplexity/perplexity.cpp | 35 |
5 files changed, 51 insertions, 76 deletions
```diff
diff --git a/examples/common.cpp b/examples/common.cpp
index 2bf0dc5..9b23b1f 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -405,6 +405,37 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx     = params.n_ctx;
+    lparams.n_parts   = params.n_parts;
+    lparams.seed      = params.seed;
+    lparams.f16_kv    = params.memory_f16;
+    lparams.use_mmap  = params.use_mmap;
+    lparams.use_mlock = params.use_mlock;
+
+    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+
+    if (lctx == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return NULL;
+    }
+
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(lctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return NULL;
+        }
+    }
+
+    return lctx;
+}
+
 /* Keep track of current color of output, and emit ANSI code if it changes. */
 void set_console_color(console_state & con_st, console_color_t color) {
     if (con_st.use_color && con_st.color != color) {
diff --git a/examples/common.h b/examples/common.h
index 627696e..138d0de 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -78,6 +78,12 @@ std::string gpt_random_prompt(std::mt19937 & rng);
 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
 
 //
+// Model utils
+//
+
+struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+
+//
 // Console utils
 //
 
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 1e9d8a8..e4b7291 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -35,24 +35,10 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 54836b3..a10256a 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -101,34 +101,11 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
     g_ctx = &ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx     = params.n_ctx;
-        lparams.n_parts   = params.n_parts;
-        lparams.seed      = params.seed;
-        lparams.f16_kv    = params.memory_f16;
-        lparams.use_mmap  = params.use_mmap;
-        lparams.use_mlock = params.use_mlock;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index d474bc5..299a199 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -122,36 +122,11 @@ int main(int argc, char ** argv) {
 
     llama_context * ctx;
 
-    // load the model
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.n_ctx      = params.n_ctx;
-        lparams.n_parts    = params.n_parts;
-        lparams.seed       = params.seed;
-        lparams.f16_kv     = params.memory_f16;
-        lparams.logits_all = params.perplexity;
-        lparams.use_mmap   = params.use_mmap;
-        lparams.use_mlock  = params.use_mlock;
-        lparams.embedding  = params.embedding;
-
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-    }
-
-    if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
-                                             params.n_threads);
-        if (err != 0) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return 1;
-        }
+    // load the model and apply lora adapter, if any
+    ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
     }
 
     // print system information
```
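For reference, a minimal sketch of how a new example program would call the consolidated helper, assuming the llama.cpp example API as of this commit. The `main` body below is illustrative only and is not part of the change; `llama_init_from_gpt_params()`, `gpt_params_parse()`, `llama_print_system_info()`, and `llama_free()` are the functions the examples above actually use.

```cpp
// Hypothetical example program, not part of this commit: it sketches the
// call pattern that main.cpp, embedding.cpp, and perplexity.cpp now share.
#include "common.h"
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    gpt_params params;
    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    // one call now covers llama_context_default_params(),
    // llama_init_from_file(), and the optional LoRA adapter
    llama_context * ctx = llama_init_from_gpt_params(params);
    if (ctx == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    fprintf(stderr, "system_info: %s\n", llama_print_system_info());

    // ... tokenize the prompt, evaluate, sample ...

    llama_free(ctx);
    return 0;
}
```

A side effect of centralizing this boilerplate is that future `llama_context_params` fields only need to be wired up once, in `examples/common.cpp`, rather than in every example's `main()`.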