From 975d2cebf97ce888fa0aeee6f5ac774d7135891f Mon Sep 17 00:00:00 2001 From: anzz1 Date: Tue, 21 Mar 2023 17:42:43 +0200 Subject: cmdline option for custom amount of model parts (--n_parts N) (#348) * cmdline option for custom amount of model parts (--n_parts N) * Update main.cpp --------- Co-authored-by: Georgi Gerganov --- main.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'main.cpp') diff --git a/main.cpp b/main.cpp index e97611e..662a2a7 100644 --- a/main.cpp +++ b/main.cpp @@ -90,7 +90,8 @@ struct llama_model { }; // load the model's weights from a file -bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) { + +bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) { fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); std::vector f_buf(1024*1024); @@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca } int n_ff = 0; - int n_parts = 0; // load hparams { @@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca hparams.n_ctx = n_ctx; n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; - n_parts = LLAMA_N_PARTS.at(hparams.n_embd); + + if (n_parts < 1) { + n_parts = LLAMA_N_PARTS.at(hparams.n_embd); + } fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx); @@ -839,7 +842,7 @@ int main(int argc, char ** argv) { { const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; const int64_t t_start_us = ggml_time_us(); - if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) { + if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } -- cgit v1.2.3