author    Georgi Gerganov <ggerganov@gmail.com>    2023-07-21 13:10:51 +0300
committer GitHub <noreply@github.com>              2023-07-21 13:10:51 +0300
commit    ae178ab46bfd6ecb2422da5dad441a4e2fef8b7e (patch)
tree      064a13d048ecd596bbd57bd081c9615aa91ebbf6
parent    54e3bc76fed914f8d4a30a7a50c19867cccb1338 (diff)
llama : make tensor_split ptr instead of array (#2272)
-rw-r--r--  examples/common.cpp  2
-rw-r--r--  ggml-cuda.cu         3
-rw-r--r--  llama.cpp            4
-rw-r--r--  llama.h              3
4 files changed, 8 insertions, 4 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index fd6dbc0..476d565 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -586,7 +586,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_batch = params.n_batch;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
- memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
+ lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
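
With the memcpy replaced by a plain assignment, lparams.tensor_split now aliases the array held inside gpt_params instead of copying it. A minimal caller-side sketch of what that implies, assuming params.tensor_split is still a fixed-size float array inside gpt_params (illustrative only, not part of the patch):

    gpt_params params;
    params.tensor_split[0] = 0.6f;   // proportion of layers for GPU 0
    params.tensor_split[1] = 0.4f;   // proportion of layers for GPU 1

    // The context params now point into params, so params must outlive them.
    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    // lparams.tensor_split == params.tensor_split (same storage, no copy)
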
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d3054a7..6537897 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2512,6 +2512,9 @@ void ggml_init_cublas() {
}
void ggml_cuda_set_tensor_split(const float * tensor_split) {
+ if (tensor_split == nullptr) {
+ return;
+ }
bool all_zero = true;
for (int i = 0; i < g_device_count; ++i) {
if (tensor_split[i] != 0.0f) {
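
The new guard makes a null pointer a legal argument: the function simply returns before the existing all_zero scan, leaving the current split untouched. A hedged sketch of the two call shapes this enables (device count and values are made up):

    // Passing nullptr is now safe and means "no explicit split requested".
    ggml_cuda_set_tensor_split(nullptr);

    // An explicit split still works as before: one proportion per device.
    const float split[2] = { 0.7f, 0.3f };
    ggml_cuda_set_tensor_split(split);
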
diff --git a/llama.cpp b/llama.cpp
index 796dfda..23e746d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -849,7 +849,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_batch =*/ 512,
/*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
- /*.tensor_split =*/ {0},
+ /*.tensor_split =*/ nullptr,
/*.rope_freq_base =*/ 10000.0f,
/*.rope_freq_scale =*/ 1.0f,
/*.progress_callback =*/ nullptr,
@@ -1289,7 +1289,7 @@ static bool llama_model_load(
int n_batch,
int n_gpu_layers,
int main_gpu,
- float * tensor_split,
+ const float * tensor_split,
float rope_freq_base,
float rope_freq_scale,
bool low_vram,
diff --git a/llama.h b/llama.h
index b676a38..c565f6a 100644
--- a/llama.h
+++ b/llama.h
@@ -88,7 +88,8 @@ extern "C" {
int32_t n_batch; // prompt processing batch size
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
- float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+ const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
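
In the public header the field becomes a caller-owned pointer whose default is nullptr (set in llama_context_default_params above). A short usage sketch against the new struct layout, assuming a multi-GPU build where LLAMA_MAX_DEVICES > 1; the model loading itself is outside this patch:

    static const float tensor_split[LLAMA_MAX_DEVICES] = { 0.5f, 0.5f };

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_gpu_layers = 40;            // layers to offload to VRAM
    cparams.tensor_split = tensor_split;  // caller-owned array; nullptr (the default) = no explicit split
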