| author    | Johannes Gäßler <johannesg@5d6.de>                | 2023-06-06 21:33:23 +0200 |
|-----------|---------------------------------------------------|---------------------------|
| committer | GitHub <noreply@github.com>                       | 2023-06-06 21:33:23 +0200 |
| commit    | 17366df842e358768c0df7024484fffecfc7865b (patch)  |                           |
| tree      | f042c8142311d45f8712db10debf89111b2c7e57 /llama.h |                           |
| parent    | 44f906e8537fcec965e312d621c80556d6aa9bec (diff)   |                           |
Multi GPU support, CUDA refactor, CUDA scratch buffer (#1703)
* CUDA multi GPU + scratch
  * `ggml_cuda_compute_forward`
  * Tensor parallelism
  * `ggml_cuda_add`
  * `ggml_cuda_rms_norm`
  * `ggml_cuda_silu`
  * CUDA scratch buffer
  * `--main-gpu` CLI option
Diffstat (limited to 'llama.h')
| -rw-r--r-- | llama.h | 16 |

1 file changed, 13 insertions, 3 deletions
```diff
@@ -1,6 +1,13 @@
 #ifndef LLAMA_H
 #define LLAMA_H
 
+#include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
+#else
+#define LLAMA_MAX_DEVICES 1
+#endif // GGML_USE_CUBLAS
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
@@ -65,9 +72,12 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
-        int n_ctx;        // text context
-        int n_gpu_layers; // number of layers to store in VRAM
-        int seed;         // RNG seed, -1 for random
+        int n_ctx;                             // text context
+        int n_batch;                           // prompt processing batch size
+        int n_gpu_layers;                      // number of layers to store in VRAM
+        int main_gpu;                          // the GPU that is used for scratch and small tensors
+        float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+        int seed;                              // RNG seed, -1 for random
 
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
```
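The effect of the new fields is easiest to see from the caller's side. Below is a minimal sketch, not part of this commit, of how application code might fill them in. It assumes the `llama_context_default_params()` and `llama_init_from_file()` entry points that `llama.h` exposed at this point in the project's history; the model path and the 60/40 split are illustrative values only.

```c
// Minimal sketch (assumption, not from this commit): populating the new
// llama_context_params fields. "model.ggml" and the split ratios are made up.
#include "llama.h"

int main(void) {
    struct llama_context_params params = llama_context_default_params();

    params.n_ctx        = 2048; // text context size
    params.n_batch      = 512;  // prompt processing batch size
    params.n_gpu_layers = 32;   // number of layers to store in VRAM
    params.main_gpu     = 0;    // GPU 0 holds the scratch buffer and small tensors

#if LLAMA_MAX_DEVICES > 1
    // Only meaningful in a cuBLAS build, where LLAMA_MAX_DEVICES comes from
    // GGML_CUDA_MAX_DEVICES; split layers 60/40 across two GPUs (illustrative).
    params.tensor_split[0] = 0.6f;
    params.tensor_split[1] = 0.4f;
#endif

    struct llama_context * ctx = llama_init_from_file("model.ggml", params);
    if (ctx == NULL) {
        return 1;
    }

    // ... evaluate tokens, sample, etc. ...

    llama_free(ctx);
    return 0;
}
```

In a non-cuBLAS build `LLAMA_MAX_DEVICES` collapses to 1, so `tensor_split` degenerates to a single element; guarding the multi-GPU assignments keeps the same source compiling on both configurations, which is presumably why the macro is defined in the public header rather than only in the CUDA backend.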