author     slaren <2141330+slaren@users.noreply.github.com>    2023-04-17 17:28:55 +0200
committer  GitHub <noreply@github.com>                         2023-04-17 17:28:55 +0200
commit     315a95a4d30db726fb7d244dd3b9e90a83fb1616 (patch)
tree       569d8140cde36ad971d3d3120556ab5533603931 /examples
parent     efd05648c88a0923a55f56e7ce1b0f9c33410afb (diff)
Add LoRA support (#820)
Diffstat (limited to 'examples')
-rw-r--r--  examples/common.cpp                  | 15
-rw-r--r--  examples/common.h                    |  7
-rw-r--r--  examples/main/main.cpp               | 11
-rw-r--r--  examples/perplexity/perplexity.cpp   | 11
4 files changed, 41 insertions, 3 deletions
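
The new flags can be tried from any of the updated examples. A hypothetical invocation (the model and adapter paths below are placeholders, not taken from this commit):

    ./main -m models/7B/ggml-model-q4_0.bin --lora lora-adapter.bin --lora-base models/7B/ggml-model-f16.bin

Per the usage text added below, --lora implies --no-mmap, and --lora-base optionally names a model to use as a base for the layers modified by the adapter.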
diff --git a/examples/common.cpp b/examples/common.cpp
index 0772dbf..a0b6f10 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -139,6 +139,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter = argv[i];
+            params.use_mmap = false;
+        } else if (arg == "--lora-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_base = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -242,6 +255,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
+    fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
diff --git a/examples/common.h b/examples/common.h
index 1ea6f74..cbbc2df 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -31,11 +31,12 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string input_prefix = ""; // string to prefix user inputs with
-
-
+    std::string input_prefix = ""; // string to prefix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
+    std::string lora_adapter = ""; // lora adapter path
+    std::string lora_base = "";    // base model path for the lora adapter
+
     bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3e4b003..b7b3c41 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -114,6 +114,17 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 19449e1..80792ea 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -134,6 +134,17 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
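
For callers outside these examples, the sequence the diff adds reduces to a few lines against llama.h. A minimal sketch, assuming the llama.h API as of this commit (llama_init_from_file, llama_context_default_params, llama_free); the file paths and thread count are placeholders:

    // Minimal sketch of applying a LoRA adapter via llama_apply_lora_from_file.
    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_context_params lparams = llama_context_default_params();
        // Mirror what --lora does in common.cpp: disable mmap, presumably so
        // the loaded tensors are writable when the adapter is applied.
        lparams.use_mmap = false;

        llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
        if (ctx == NULL) {
            fprintf(stderr, "error: failed to load model\n");
            return 1;
        }

        // Pass NULL as the base model to apply the adapter directly on top of
        // the loaded model; a non-NULL path supplies a separate base for the
        // layers the adapter modifies.
        int err = llama_apply_lora_from_file(ctx, "lora-adapter.bin", NULL, /*n_threads=*/4);
        if (err != 0) {
            fprintf(stderr, "error: failed to apply lora adapter\n");
            llama_free(ctx);
            return 1;
        }

        // ... tokenize and evaluate as usual ...

        llama_free(ctx);
        return 0;
    }

This matches the pattern added to main.cpp and perplexity.cpp above, where lora_base is forwarded as NULL when the --lora-base flag was not given.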