aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-07-21 14:51:34 +0300
committerGeorgi Gerganov <ggerganov@gmail.com>2023-07-21 14:51:34 +0300
commit513f8619535a64fa9ace808cdcbcf66211535f5c (patch)
tree2735871213e737a2183f85265add5c4120e190dc
parent3973b25a64a37a47eac156a3fd28f83c16f14bf2 (diff)
ggml : fix rope args order + assert (#2054)
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp6
-rw-r--r--ggml.c24
-rw-r--r--ggml.h7
-rw-r--r--llama.cpp4
4 files changed, 23 insertions, 18 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index afbb4a7..449b4e9 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
gf->perf_time_us = 0;
const auto & hparams = model->hparams;
- //const int n_ctx = hparams.n_ctx;
+ const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
@@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
- t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+ t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
- t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+ t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
t04->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_add_inplace(ctx0,
diff --git a/ggml.c b/ggml.c
index c56a3d0..7ecabc5 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
int n_past,
int n_dims,
int mode,
+ int n_ctx,
float freq_base,
float freq_scale,
- int n_ctx,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
}
struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
}
struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
int n_past,
int n_dims,
int mode,
+ int n_ctx,
float freq_base,
- float freq_scale,
- int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+ float freq_scale) {
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
}
// ggml_rope_back
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * a,
int n_past,
int n_dims,
- int mode) {
+ int mode,
+ int n_ctx) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
ggml_scratch_save(ctx);
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
ggml_set_name(b, "n_past, n_dims, mode");
((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_dims;
((int32_t *) b->data)[2] = mode;
+ ((int32_t *) b->data)[3] = n_ctx;
ggml_scratch_load(ctx);
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
+ const int n_ctx = ((int32_t *) src1->data)[3];
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
n_past,
n_dims,
- mode),
+ mode,
+ n_ctx),
inplace);
}
if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
- assert(ggml_nelements(src1) == 3);
+ assert(ggml_nelements(src1) == 4);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
diff --git a/ggml.h b/ggml.h
index 24856a2..5023b16 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1128,9 +1128,9 @@ extern "C" {
int n_past,
int n_dims,
int mode,
+ int n_ctx,
float freq_base,
- float freq_scale,
- int n_ctx);
+ float freq_scale);
// rotary position embedding backward, i.e compute dx from dy
// a - dy
@@ -1139,7 +1139,8 @@ extern "C" {
struct ggml_tensor * a,
int n_past,
int n_dims,
- int mode);
+ int mode,
+ int n_ctx);
// alibi position embedding
// in-place, returns view(a)
diff --git a/llama.cpp b/llama.cpp
index 3b0024e..0a381af 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1452,11 +1452,11 @@ static bool llama_eval_internal(
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
- struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+ struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
- struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+ struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");