author    Georgi Gerganov <ggerganov@gmail.com>   2023-07-23 15:09:47 +0300
committer GitHub <noreply@github.com>             2023-07-23 15:09:47 +0300
commit    e76d630df17e235e6b9ef416c45996765d2e36fb (patch)
tree      15e0e9648f9b0e398b43e888216a73f84098ff3a /examples
parent    1d0824b2476e7fda09751a0235c9e571b76d6f2c (diff)
llama : grouped-query attention + LLaMAv2 70B support (#2276)
* CUDA: GQA implementation
* llama : support for GQA and LLaMAv2 70B  ggml-ci
* py : fix hparams parsing (if-else blocks)  ggml-ci
* py : oh boy ..  ggml-ci
* help : fix gqa value for 70B  ggml-ci

---------

Co-authored-by: JohannesGaessler <johannesg@5d6.de>
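For readers landing on this commit: the new -gqa factor is the ratio between query heads and key/value heads in grouped-query attention, i.e. each group of n_gqa query heads shares one KV head. A minimal standalone sketch of that arithmetic follows; the 70B head count of 64 is an assumption taken from the published LLaMA v2 configuration, not something read from this diff, while the factor of 8 comes from the new --gqa help text below.

// Sketch only: what the grouped-query attention (GQA) factor encodes.
// Assumes LLaMAv2 70B uses 64 query heads (from the published model config).
#include <cassert>
#include <cstdio>

int main() {
    const int n_head = 64; // query heads (assumed for LLaMAv2 70B)
    const int n_gqa  = 8;  // value the new --gqa help text suggests for 70B

    assert(n_head % n_gqa == 0);
    const int n_head_kv = n_head / n_gqa; // each group of 8 query heads shares one KV head

    printf("query heads: %d, kv heads: %d, gqa factor: %d\n", n_head, n_head_kv, n_gqa);
    // n_gqa == 1 is ordinary multi-head attention; n_gqa == n_head is multi-query attention.
    return 0;
}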
Diffstat (limited to 'examples')
-rw-r--r--   examples/common.cpp     | 12
-rw-r--r--   examples/common.h       |  3
-rw-r--r--   examples/main/main.cpp  |  4
3 files changed, 14 insertions, 5 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 6610397..5608ca8 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gqa = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -f FNAME, --file FNAME\n");
     fprintf(stdout, "                        prompt file to start generation.\n");
     fprintf(stdout, "  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -505,7 +514,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -513,7 +521,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stdout, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stdout, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
     fprintf(stdout, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
     fprintf(stdout, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.n_ctx        = params.n_ctx;
     lparams.n_batch      = params.n_batch;
+    lparams.n_gqa        = params.n_gqa;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
     lparams.tensor_split = params.tensor_split;
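Taken together, the common.cpp hunks above add the -gqa/--gqa argument, document it in the usage text, and copy the value into llama_context_params. A minimal sketch of that flow, assuming it is built inside the examples/ tree so that common.h (and therefore llama.h) is available; only functions that appear in this diff are used.

// Sketch only: how the new flag travels from argv into the context parameters.
#include "common.h"  // gpt_params, gpt_params_parse, llama_context_params_from_gpt_params

#include <cstdio>

int main(int argc, char ** argv) {
    gpt_params params; // n_gqa defaults to 1 (see the common.h hunk below)

    // e.g. invoked as: ./sketch -m models/70B/ggml-model.bin -gqa 8 -c 2048
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // the grouped-query attention factor is forwarded verbatim
    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    fprintf(stderr, "n_gqa = %d\n", (int) lparams.n_gqa);

    return 0;
}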
diff --git a/examples/common.h b/examples/common.h
index c936de6..fb8f6d6 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -27,6 +27,7 @@ struct gpt_params {
     int32_t n_predict    = -1;  // new tokens to predict
     int32_t n_ctx        = 512; // context size
     int32_t n_batch      = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa        = 1;   // grouped-query attention factor (TODO: move to hparams)
     int32_t n_keep       = 0;   // number of tokens to keep from initial prompt
     int32_t n_chunks     = -1;  // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0;   // number of layers to store in VRAM
@@ -47,7 +48,7 @@ struct gpt_params {
     int32_t repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   frequency_penalty = 0.00f; // 0.0 = disabled
     float   presence_penalty  = 0.00f; // 0.0 = disabled
-    int     mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 4b4cd1d..3bd8ba2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
     }

     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
-                " you are on your own\n", __func__, params.n_ctx);
+        // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
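The main.cpp hunk keeps 2048 as a hard-coded warning threshold and leaves the per-model limit as a TODO. Below is a hedged sketch of what that TODO could become once the model's training context is known to the caller; n_ctx_train is just a plain parameter standing in for that value (e.g. 2048 for LLaMA v1, 4096 for LLaMA v2), not an API that exists in this commit.

// Hypothetical follow-up to the TODO above; not part of this commit.
#include <cstdio>

static int sanitize_n_ctx(int n_ctx, int n_ctx_train) {
    if (n_ctx > n_ctx_train) {
        // warn but keep the requested size, mirroring the behaviour of the code above
        fprintf(stderr, "warning: model only supports context sizes no greater than %d tokens (%d specified)\n",
                n_ctx_train, n_ctx);
    } else if (n_ctx < 8) {
        fprintf(stderr, "warning: minimum context size is 8, using minimum size\n");
        n_ctx = 8;
    }
    return n_ctx;
}

int main() {
    // Example: a LLaMA v2 model (trained at 4096 tokens) asked for an 8192-token context.
    const int n_ctx = sanitize_n_ctx(8192, 4096);
    fprintf(stderr, "using n_ctx = %d\n", n_ctx);
    return 0;
}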