aboutsummaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2023-07-24 17:57:12 +0200
committerGitHub <noreply@github.com>2023-07-24 17:57:12 +0200
commit41c674161fb2459bdf7806d1eebead15bc5d046e (patch)
tree0a211224c924a579287762cc7492fe1c9fcf3509 /ggml.c
parentb3f138d05849ccbce67303ac17b50ebbc268128a (diff)
make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter * add rms_norm_eps to command line * fix baby llama, test-grad0 * use scientific notation for eps param in the help ggml-ci
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c16
1 files changed, 10 insertions, 6 deletions
diff --git a/ggml.c b/ggml.c
index 960b805..11226c8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ float eps,
bool inplace) {
bool is_node = false;
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- // TODO: maybe store epsilon here?
+ ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, false);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, true);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, true);
}
struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_TENSOR_UNARY_OP_LOCALS;
- const float eps = 1e-6f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {