diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-07-23 15:09:47 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 15:09:47 +0300 |
commit | e76d630df17e235e6b9ef416c45996765d2e36fb (patch) | |
tree | 15e0e9648f9b0e398b43e888216a73f84098ff3a /llama.cpp | |
parent | 1d0824b2476e7fda09751a0235c9e571b76d6f2c (diff) |
llama : grouped-query attention + LLaMAv2 70B support (#2276)
* CUDA: GQA implementation
* llama : support for GQA and LLaMAv2 70B
ggml-ci
* py : fix hparams parsing (if-else blocks)
ggml-ci
* py : oh boy ..
ggml-ci
* help : fix gqa value for 70B
ggml-ci
---------
Co-authored-by: JohannesGaessler <johannesg@5d6.de>
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 156 |
1 files changed, 105 insertions, 51 deletions
@@ -67,6 +67,7 @@ enum e_model { MODEL_13B, MODEL_30B, MODEL_65B, + MODEL_70B, }; static const size_t kB = 1024; @@ -109,6 +110,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx) { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB }, { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB }, { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess + { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB }, }; return k_sizes; } @@ -121,6 +123,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1() { MODEL_13B, 192ull * MB }, { MODEL_30B, 256ull * MB }, { MODEL_65B, 384ull * MB }, // guess + { MODEL_70B, 304ull * MB }, }; return k_sizes; } @@ -134,6 +137,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL() { MODEL_13B, 12ull * MB }, { MODEL_30B, 16ull * MB }, { MODEL_65B, 24ull * MB }, // guess + { MODEL_70B, 24ull * MB }, }; return k_sizes; } @@ -148,6 +152,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE() { MODEL_13B, 640ull * kB }, { MODEL_30B, 768ull * kB }, { MODEL_65B, 1536ull * kB }, + { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced) }; return k_sizes; } @@ -162,19 +167,25 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT() { MODEL_13B, 160ull }, { MODEL_30B, 208ull }, { MODEL_65B, 416ull }, + { MODEL_70B, 416ull }, // TODO (likely can be reduced) }; return k_sizes; } // default hparams (LLaMA 7B) struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 256; - uint32_t n_head = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_mult = 256; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + // LLaMAv2 + // TODO: load from model data hparams + float f_ffn_mult = 1.0f; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; @@ -182,12 +193,24 @@ struct llama_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; bool operator!=(const llama_hparams & other) const { - return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); + return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + } + + uint32_t n_gqa() const { + return n_head/n_head_kv; + } + + uint32_t n_embd_head() const { + return n_embd/n_head; + } + + uint32_t n_embd_gqa() const { + return n_embd/n_gqa(); } size_t kv_size() const { size_t result = 2ull; - result *= (size_t) n_embd; + result *= (size_t) n_embd_gqa(); result *= (size_t) n_ctx; result *= (size_t) n_layer; result *= sizeof(ggml_fp16_t); @@ -493,12 +516,16 @@ struct llama_file_loader { } void read_hparams() { hparams.n_vocab = file.read_u32(); - hparams.n_embd = file.read_u32(); - hparams.n_mult = file.read_u32(); - hparams.n_head = file.read_u32(); + hparams.n_embd = file.read_u32(); + hparams.n_mult = file.read_u32(); + hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); - hparams.n_rot = file.read_u32(); - hparams.ftype = (enum llama_ftype) file.read_u32(); + hparams.n_rot = file.read_u32(); + hparams.ftype = (enum llama_ftype) file.read_u32(); + + // LLaMAv2 + // TODO: read from header + hparams.n_head_kv = hparams.n_head; } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -797,7 +824,7 @@ static bool kv_cache_init( ggml_type wtype, int n_ctx, int n_gpu_layers) { - const int n_embd = hparams.n_embd; + const int n_embd = hparams.n_embd_gqa(); const int n_layer = hparams.n_layer; const int64_t n_mem = n_layer*n_ctx; @@ -841,6 +868,7 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_gqa =*/ 1, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -960,6 +988,7 @@ static const char *llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_30B: return "30B"; case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; default: LLAMA_ASSERT(false); } } @@ -970,6 +999,7 @@ static void llama_model_load_internal( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -991,6 +1021,7 @@ static void llama_model_load_internal( model.hparams = ml->file_loader->hparams; model.n_gpu_layers = n_gpu_layers; llama_file_version file_version = ml->file_loader->file_version; + auto & hparams = model.hparams; { @@ -1010,11 +1041,25 @@ static void llama_model_load_internal( hparams.n_ctx = n_ctx; + // LLaMAv2 + // TODO: temporary until GGUF + LLAMA_ASSERT(hparams.n_head % n_gqa == 0); + hparams.n_head_kv = hparams.n_head / n_gqa; + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model + } + hparams.rope_freq_base = rope_freq_base; hparams.rope_freq_scale = rope_freq_scale; } - const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 + const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3; + const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw; + const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + //const uint32_t n_ff = 28672; { fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); @@ -1023,12 +1068,14 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } @@ -1098,9 +1145,10 @@ static void llama_model_load_internal( size_t vram_weights = 0; size_t vram_scratch = 0; { - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd_gqa = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; ml->ggml_ctx = ctx; @@ -1148,16 +1196,16 @@ static void llama_model_load_internal( layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); + layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); + layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); + layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1281,6 +1329,7 @@ static bool llama_model_load( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1294,7 +1343,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1338,17 +1387,22 @@ static bool llama_eval_internal( LLAMA_ASSERT(!!kv_self.ctx); - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - const int n_rot = hparams.n_embd/hparams.n_head; - const int n_gpu_layers = model.n_gpu_layers; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + LLAMA_ASSERT(n_embd_head == hparams.n_rot); const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const int n_gpu_layers = model.n_gpu_layers; + auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; @@ -1446,11 +1500,11 @@ static bool llama_eval_internal( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -1462,17 +1516,17 @@ static bool llama_eval_internal( offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -1491,8 +1545,8 @@ static bool llama_eval_internal( struct ggml_tensor * K = ggml_permute(ctx0, ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), - n_embd/n_head, n_head, n_past + N), + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa), + n_embd_head, n_head_kv, n_past + N), 0, 2, 1, 3); offload_func_kq(K); ggml_set_name(K, "K"); @@ -1502,9 +1556,9 @@ static bool llama_eval_internal( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled = KQ / sqrt(n_embd_head) struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); - ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); // KQ_scaled shape [n_past + N, N, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); @@ -1524,10 +1578,10 @@ static bool llama_eval_internal( // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd/n_head, n_head, + n_past + N, n_embd_head, n_head_kv, n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + n_ctx*ggml_element_size(kv_self.v)*n_embd_head, + n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); @@ -1539,7 +1593,7 @@ static bool llama_eval_internal( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -2693,7 +2747,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, + if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { |