about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- llama.cpp 12
1 file changed, 7 insertions, 5 deletions
diff --git a/llama.cpp b/llama.cpp
index 7de3c19..d552192 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -727,11 +727,13 @@ static bool llama_eval_internal(
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans =
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
- n_embd/n_head, n_head, n_past + N),
- 1, 2, 0, 3);
+ ggml_cpy(ctx0,
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+ n_embd/n_head, n_head, n_past + N),
+ 1, 2, 0, 3),
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);