diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-05-20 15:34:45 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-20 15:34:45 +0300 |
commit | 3de84b26066d95068409c1dc79bcc41c1eea2a03 (patch) | |
tree | bbfba33243550bd0db214bbbb6e323ca5f885fef | |
parent | affc76edfdefa7b326f526e463cc65ff13fcfb92 (diff) |
ggml : add ggml_clamp() (#1539)
* ggml : add ggml_clamp()
* ggml : indentation
-rw-r--r-- | ggml.c | 158 | ||||
-rw-r--r-- | ggml.h | 14 |
2 files changed, 154 insertions, 18 deletions
@@ -3472,6 +3472,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "ROPE", "ROPE_BACK", "ALIBI", + "CLAMP", "CONV_1D_1S", "CONV_1D_2S", @@ -3482,7 +3483,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "MAP_BINARY", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3532,6 +3534,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope(x)", "rope_back(x)", "alibi(x)", + "clamp(x)", "conv_1d_1s(x)", "conv_1d_2s(x)", @@ -3542,7 +3545,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "f(x,y)", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -6214,7 +6217,8 @@ struct ggml_tensor * ggml_alibi( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head) { + int n_head, + float bias_max) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6233,6 +6237,8 @@ struct ggml_tensor * ggml_alibi( ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_head; + GGML_ASSERT(sizeof(float) == sizeof(int32_t)); + (((float *) b->data)[2]) = bias_max; ggml_scratch_load(ctx); @@ -6244,6 +6250,40 @@ struct ggml_tensor * ggml_alibi( return result; } +// ggml_clamp + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + ggml_scratch_save(ctx); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + + ((float *) b->data)[0] = min; + ((float *) b->data)[1] = max; + + ggml_scratch_load(ctx); + + result->op = GGML_OP_CLAMP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_conv_1d_1s struct ggml_tensor * ggml_conv_1d_1s( @@ -10553,6 +10593,7 @@ static void ggml_compute_forward_diag_mask_f32( const int n_past = ((int32_t *) src1->data)[0]; const bool inplace = (bool)((int32_t *) src1->data)[1]; + assert(n_past >= 0); if (!inplace && (params->type == GGML_TASK_INIT)) { @@ -10723,14 +10764,15 @@ static void ggml_compute_forward_alibi_f32( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10753,8 +10795,8 @@ static void ggml_compute_forward_alibi_f32( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10772,13 +10814,13 @@ static void ggml_compute_forward_alibi_f32( m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); } - pdst[0] = i * m_k + src[0]; + pdst[0] = (i-ne0+1) * m_k + src[0]; + } } } } - static void ggml_compute_forward_alibi_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10786,14 +10828,15 @@ static void ggml_compute_forward_alibi_f16( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10816,8 +10859,8 @@ static void ggml_compute_forward_alibi_f16( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10836,7 +10879,7 @@ static void ggml_compute_forward_alibi_f16( } // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); + pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); } } } @@ -10872,6 +10915,77 @@ static void ggml_compute_forward_alibi( } } + +// ggml_compute_forward_clamp + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 2); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int min = ((float *) src1->data)[0]; + const int max = ((float *) src1->data)[1]; + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_rope static void ggml_compute_forward_rope_f32( @@ -12853,6 +12967,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_CONV_1D_1S: { ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); @@ -13160,6 +13278,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SILU: { // necessary for llama @@ -14039,6 +14161,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; //TODO } break; + case GGML_OP_CLAMP: + { + node->n_tasks = 1; //TODO + } break; case GGML_OP_CONV_1D_1S: case GGML_OP_CONV_1D_2S: { @@ -313,6 +313,7 @@ extern "C" { GGML_OP_ROPE, GGML_OP_ROPE_BACK, GGML_OP_ALIBI, + GGML_OP_CLAMP, GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_2S, @@ -849,7 +850,7 @@ extern "C" { int n_past); // in-place, returns view(a) - GGML_API struct ggml_tensor * gml_diag_mask_zero_inplace( + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( struct ggml_context * ctx, struct ggml_tensor * a, int n_past); @@ -897,7 +898,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head); + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); // padding = 1 // TODO: we don't support extra parameters for now |