aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-04-05 22:07:33 +0300
committerGitHub <noreply@github.com>2023-04-05 22:07:33 +0300
commit986b6ce9f99503c51ec5afd8a10baa32359434c6 (patch)
treef4655b45b130b908729eb1407ca9e016c05f21a4 /llama.cpp
parent34162989297fdfe3ab7305451ce55bc87e3f4c9c (diff)
ggml, llama : avoid heavy V transpose + improvements (#775)
ggml : - added ggml_view_3d() - ggml_view_tensor() now inherits the stride too - reimplement ggml_cpy() to account for dst stride - no longer require tensor->data to be memory aligned llama : - compute RoPE on 32-bit tensors (should be more accurate) - store RoPE-ed K in the KV cache - store transposed V in the KV cache (significant speed-up) - avoid unnecessary Q copy
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp67
1 files changed, 37 insertions, 30 deletions
diff --git a/llama.cpp b/llama.cpp
index e451795..581a839 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -810,37 +810,35 @@ static bool llama_eval_internal(
// self-attention
{
- struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
- struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ // compute Q and K and RoPE them
+ struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+ struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
// store key and value to memory
- if (N >= 1) {
+ {
+ // compute the transposed [N, n_embd] V matrix
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+ // important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
- // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ggml_tensor * Q =
ggml_permute(ctx0,
- ggml_rope(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
- n_past, n_rot, 0),
+ Qcur,
0, 2, 1, 3);
- // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ggml_tensor * K =
ggml_permute(ctx0,
- ggml_rope(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
- n_embd/n_head, n_head, n_past + N),
- n_past, n_rot, 1),
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+ n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// K * Q
@@ -858,18 +856,23 @@ static bool llama_eval_internal(
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
- // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
- struct ggml_tensor * V_trans =
- ggml_cpy(ctx0,
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
- n_embd/n_head, n_head, n_past + N),
- 1, 2, 0, 3),
- ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+ // split cached V into n_head heads
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_past + N, n_embd/n_head, n_head,
+ n_ctx*ggml_element_size(kv_self.v),
+ n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
+ il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
- // KQV = transpose(V) * KQ_soft_max
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+#if 1
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+#else
+ // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+ // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+ // is there a better way?
+ struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -955,9 +958,13 @@ static bool llama_eval_internal(
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);
+ // print timing information per ggml operation (for debugging purposes)
+ // requires GGML_PERF to be defined
+ //ggml_graph_print(&gf);
+
+ // plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
- // ggml_graph_print (&gf);
- // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+ // ggml_graph_dump_dot(&gf, NULL, "llama.dot");
//}
//embd_w.resize(n_vocab*N);