aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhoangmit <hoangmit@users.noreply.github.com>2023-03-15 18:41:38 -0400
committerGitHub <noreply@github.com>2023-03-16 00:41:38 +0200
commit6eac39ba953acaeec396cea2969dbf413907e2ec (patch)
tree1c263cf2672564d1fba72b2d7e47037cea8345c3
parent27944c4206a49bbe003021a2610bacaa3044e619 (diff)
Add RMS norm and use it (#187)
* add ggml_rms_norm * update op num
-rw-r--r--ggml.c128
-rw-r--r--ggml.h5
-rw-r--r--main.cpp6
3 files changed, 134 insertions, 5 deletions
diff --git a/ggml.c b/ggml.c
index a0c0dd0..eee54f7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"GELU",
"SILU",
"NORM",
+ "RMS_NORM",
"MUL_MAT",
@@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gelu(x)",
"silu(x)",
"norm(x)",
+ "rms_norm(x)",
"X*Y",
@@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
//
// ggml object
@@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
return ggml_norm_impl(ctx, a, true);
}
+struct ggml_tensor * ggml_rms_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_RMS_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL; // TODO: maybe store epsilon here?
+
+ return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_rms_norm_impl(ctx, a, true);
+}
+
// ggml_mul_mat
struct ggml_tensor * ggml_mul_mat(
@@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm(
}
}
+static void ggml_compute_forward_rms_norm_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int ne00 = src0->ne[0];
+ const int ne01 = src0->ne[1];
+ const int ne02 = src0->ne[2];
+ const int ne03 = src0->ne[3];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb02 = src0->nb[2];
+ const size_t nb03 = src0->nb[3];
+
+ const size_t nb1 = dst->nb[1];
+ const size_t nb2 = dst->nb[2];
+ const size_t nb3 = dst->nb[3];
+
+ const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+ // TODO: optimize
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float mean = 0.0;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ mean += x[i00] * x[i00];
+ }
+
+ mean /= ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ memcpy(y, x, ne00 * sizeof(float));
+ // for (int i00 = 0; i00 < ne00; i00++) {
+ // y[i00] = x[i00];
+ // }
+
+ const float scale = 1.0/sqrt(mean + eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rms_norm(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rms_norm_f32(params, src0, dst);
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
// ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_norm(params, tensor->src0, tensor);
} break;
+ case GGML_OP_RMS_NORM:
+ {
+ ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+ } break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_MUL_MAT:
{
if (src0->grad) {
diff --git a/ggml.h b/ggml.h
index 7ce655c..bac4fe6 100644
--- a/ggml.h
+++ b/ggml.h
@@ -230,6 +230,7 @@ enum ggml_op {
GGML_OP_GELU,
GGML_OP_SILU,
GGML_OP_NORM, // normalize
+ GGML_OP_RMS_NORM,
GGML_OP_MUL_MAT,
@@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
+struct ggml_tensor * ggml_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
diff --git a/main.cpp b/main.cpp
index a812d0f..ca0fca8 100644
--- a/main.cpp
+++ b/main.cpp
@@ -588,7 +588,7 @@ bool llama_eval(
// norm
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -678,7 +678,7 @@ bool llama_eval(
{
// norm
{
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF);
// cur = ffn_norm*cur
cur = ggml_mul(ctx0,
@@ -713,7 +713,7 @@ bool llama_eval(
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL);
// inpL = norm*inpL
inpL = ggml_mul(ctx0,