aboutsummaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
authoranzz1 <anzz1@live.com>2023-03-21 17:42:43 +0200
committerGitHub <noreply@github.com>2023-03-21 17:42:43 +0200
commit975d2cebf97ce888fa0aeee6f5ac774d7135891f (patch)
treee578d57ca7ccef7851e1f02dfe15887ff829aec4 /main.cpp
parente0ffc861fae5ac8b40ce973f822d03db02929d36 (diff)
cmdline option for custom number of model parts (--n_parts N) (#348)
* cmdline option for custom number of model parts (--n_parts N) * Update main.cpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'main.cpp')
-rw-r--r--main.cpp11
1 file changed, 7 insertions, 4 deletions
diff --git a/main.cpp b/main.cpp
index e97611e..662a2a7 100644
--- a/main.cpp
+++ b/main.cpp
@@ -90,7 +90,8 @@ struct llama_model {
};
// load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
+
+bool llama_model_load(const std::string & fname, llama_model & model, llama_vocab & vocab, int n_ctx, int n_parts, ggml_type memory_type = GGML_TYPE_F32) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
std::vector<char> f_buf(1024*1024);
@@ -127,7 +128,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
}
int n_ff = 0;
- int n_parts = 0;
// load hparams
{
@@ -145,7 +145,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
hparams.n_ctx = n_ctx;
n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
- n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+
+ if (n_parts < 1) {
+ n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+ }
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
@@ -839,7 +842,7 @@ int main(int argc, char ** argv) {
{
const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
const int64_t t_start_us = ggml_time_us();
- if (!llama_model_load(params.model, model, vocab, params.n_ctx, memory_type)) {
+ if (!llama_model_load(params.model, model, vocab, params.n_ctx, params.n_parts, memory_type)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}