diff options
-rw-r--r-- | ggml.c | 82 | ||||
-rw-r--r-- | ggml.h | 7 |
2 files changed, 76 insertions, 13 deletions
@@ -6778,6 +6778,7 @@ struct ggml_tensor * ggml_rope_impl( int n_past, int n_dims, int mode, + int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6790,11 +6791,12 @@ struct ggml_tensor * ggml_rope_impl( ggml_scratch_save(ctx); - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_dims; ((int32_t *) b->data)[2] = mode; + ((int32_t *) b->data)[3] = n_ctx; ggml_scratch_load(ctx); @@ -6811,8 +6813,9 @@ struct ggml_tensor * ggml_rope( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false); + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false); } struct ggml_tensor * ggml_rope_inplace( @@ -6820,8 +6823,9 @@ struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true); + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true); } // ggml_rope_back @@ -12440,7 +12444,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); + GGML_ASSERT(ggml_nelements(src1) == 4); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12449,6 +12453,7 @@ static void ggml_compute_forward_rope_f32( const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; assert(n_past >= 0); @@ -12493,6 +12498,7 @@ static void ggml_compute_forward_rope_f32( const float theta_scale = powf(10000.0, -2.0f/n_dims); const bool is_neox = mode & 2; + const bool is_glm = mode & 4; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { @@ -12503,7 +12509,32 @@ static void ggml_compute_forward_rope_f32( float theta = (float)p; - if (!is_neox) { + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + const float x2 = src[n_dims]; + const float x3 = src[n_dims/2*3]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; + dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; + } + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -12553,7 +12584,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); + GGML_ASSERT(ggml_nelements(src1) == 4); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12562,6 +12593,7 @@ static void ggml_compute_forward_rope_f16( const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; assert(n_past >= 0); @@ -12606,6 +12638,7 @@ static void ggml_compute_forward_rope_f16( const float theta_scale = powf(10000.0, -2.0f/n_dims); const bool is_neox = mode & 2; + const bool is_glm = mode & 4; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { @@ -12616,7 +12649,32 @@ static void ggml_compute_forward_rope_f16( float theta = (float)p; - if (!is_neox) { + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x2 = GGML_FP16_TO_FP32(src[n_dims]); + const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); + dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); + } + } if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -16189,17 +16247,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + assert(ggml_nelements(src1) == 4); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope(ctx, tensor->grad, n_past, n_dims, - mode), + mode, + n_ctx), inplace); } if (src1->grad) { @@ -1036,13 +1036,15 @@ extern "C" { // rotary position embedding // if mode & 1 == 1, skip n_past elements // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style // TODO: avoid creating a new tensor every time GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_inplace( @@ -1050,7 +1052,8 @@ extern "C" { struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // rotary position embedding backward, i.e compute dx from dy // a - dy |