diff options
Diffstat (limited to 'examples/train-text-from-scratch')
-rw-r--r-- | examples/train-text-from-scratch/train-text-from-scratch.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index afbb4a7..449b4e9 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( gf->perf_time_us = 0; const auto & hparams = model->hparams; - //const int n_ctx = hparams.n_ctx; + const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; @@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head); t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd); t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch); - t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); + t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch); t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch); - t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); + t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch); t04->grad = expand(gb, ggml_add_inplace(ctx0, ggml_add_inplace(ctx0, |