author     slaren <slarengh@gmail.com>  2023-07-24 17:57:12 +0200
committer  GitHub <noreply@github.com>  2023-07-24 17:57:12 +0200
commit     41c674161fb2459bdf7806d1eebead15bc5d046e (patch)
tree       0a211224c924a579287762cc7492fe1c9fcf3509
parent     b3f138d05849ccbce67303ac17b50ebbc268128a (diff)
make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter
* add rms_norm_eps to command line
* fix baby llama, test-grad0
* use scientific notation for eps param in the help

ggml-ci
-rw-r--r--  examples/baby-llama/baby-llama.cpp                             | 20
-rw-r--r--  examples/common.cpp                                            |  8
-rw-r--r--  examples/common.h                                              | 23
-rw-r--r--  examples/train-text-from-scratch/train-text-from-scratch.cpp  | 32
-rw-r--r--  ggml-cuda.cu                                                   | 13
-rw-r--r--  ggml-metal.m                                                   |  3
-rw-r--r--  ggml.c                                                         | 16
-rw-r--r--  ggml.h                                                         |  7
-rw-r--r--  llama.cpp                                                      | 20
-rw-r--r--  llama.h                                                        |  1
-rw-r--r--  tests/test-grad0.c                                             |  2
11 files changed, 89 insertions, 56 deletions
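
The net effect of the patch is a signature change: ggml_rms_norm() (and its _inplace variant) now takes the epsilon explicitly instead of hard-coding 1e-6f inside the op, and every caller passes it through. A minimal before/after sketch of a call site (names such as ctx0, inpL and rms_norm_eps stand for whatever the caller already has in scope):

    // before this commit: eps was fixed to 1e-6f inside ggml
    // cur = ggml_rms_norm(ctx0, inpL);

    // after this commit: the caller supplies eps
    cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);  // e.g. 1e-6f, or 1e-5f for LLaMA v2
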
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 4965881..f9dc0aa 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -8,6 +8,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+static const float rms_norm_eps = 1e-6f;
+
float frand() {
return (float)rand()/(float)RAND_MAX;
}
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
{
// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
{
// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
{
// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
diff --git a/examples/common.cpp b/examples/common.cpp
index 779605f..0e88a12 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gqa = std::stoi(argv[i]);
+ } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+ fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
+ lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
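
With the parser above in place, the epsilon travels from the command line (-eps / --rms-norm-eps) through gpt_params into llama_context_params. A hedged sketch of doing the same plumbing programmatically (not code from this commit):

    // sketch: override eps in code instead of via -eps / --rms-norm-eps
    gpt_params params;                 // from examples/common.h
    params.rms_norm_eps = 1e-5f;       // the value the help text suggests for LLaMA v2

    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    // lparams.rms_norm_eps now carries 1e-5f and is later consumed by llama_model_load
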
diff --git a/examples/common.h b/examples/common.h
index 7086606..894a085 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -22,18 +22,19 @@
int32_t get_num_physical_cores();
struct gpt_params {
- uint32_t seed = -1; // RNG seed
+ uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
- int32_t n_predict = -1; // new tokens to predict
- int32_t n_ctx = 512; // context size
- int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
- int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
- int32_t n_gpu_layers = 0; // number of layers to store in VRAM
- int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
- float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+ int32_t n_gpu_layers = 0; // number of layers to store in VRAM
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+ float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ float rms_norm_eps = 1e-6; // rms norm epsilon
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 449b4e9..4bbf6b7 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+static const float rms_norm_eps = 1e-6f;
+
struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{
// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{
// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{
// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm
{
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{
// norm
{
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm
{
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
- use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+ use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+ use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
}
clr_buf(0);
use_buf(0);
- struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+ struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1);
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b8c9835..87a1660 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-6f;
-
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
// compute
- rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
(void) src1;
(void) dst;
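
For reference, this is what the kernel computes per row once eps arrives: each row of ncols values is scaled by the reciprocal square root of its mean square plus eps. A plain scalar sketch (an illustrative helper, not the actual CUDA code or part of the patch):

    #include <math.h>

    // reference RMS norm for one row: y[i] = x[i] / sqrt(mean(x*x) + eps)
    static void rms_norm_row_ref(const float * x, float * y, int ncols, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            sum += x[i] * x[i];
        }
        const float scale = 1.0f / sqrtf(sum / ncols + eps);
        for (int i = 0; i < ncols; ++i) {
            y[i] = x[i] * scale;
        }
    }
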
diff --git a/ggml-metal.m b/ggml-metal.m
index 1fd6e85..c1db3d1 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 512;
diff --git a/ggml.c b/ggml.c
index 960b805..11226c8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ float eps,
bool inplace) {
bool is_node = false;
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- // TODO: maybe store epsilon here?
+ ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, false);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, true);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, true);
}
struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_TENSOR_UNARY_OP_LOCALS;
- const float eps = 1e-6f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
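
The mechanism that carries eps across backends is the tensor's op_params blob: ggml_rms_norm_impl() packs the float in at graph-construction time, and each forward implementation (CPU here, CUDA and Metal above) copies it back out at compute time. A condensed sketch of the two ends (fragments; result and dst are the same tensor seen at build time and at compute time):

    // producer (graph construction, ggml_rms_norm_impl):
    float eps = 1e-5f;
    ggml_set_op_params(result, &eps, sizeof(eps));

    // consumer (forward pass, any backend):
    float eps_used;
    memcpy(&eps_used, dst->op_params, sizeof(float));
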
diff --git a/ggml.h b/ggml.h
index de44fba..1870b62 100644
--- a/ggml.h
+++ b/ggml.h
@@ -866,14 +866,17 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
// a - x
// b - dy
+ // TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
diff --git a/llama.cpp b/llama.cpp
index 0288f7e..b42b410 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -186,6 +186,7 @@ struct llama_hparams {
// LLaMAv2
// TODO: load from model data hparams
float f_ffn_mult = 1.0f;
+ float f_rms_norm_eps = 1e-6f;
float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f;
@@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_gqa =*/ 1,
+ /*.rms_norm_eps =*/ 1e-6f,
/*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
int n_ctx,
int n_batch,
int n_gqa,
+ float rms_norm_eps,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
@@ -1024,6 +1027,9 @@ static void llama_model_load_internal(
auto & hparams = model.hparams;
+ // TODO: read from file
+ hparams.f_rms_norm_eps = rms_norm_eps;
+
{
switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_3B; break;
@@ -1072,6 +1078,7 @@ static void llama_model_load_internal(
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa());
+ fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps);
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
@@ -1330,6 +1337,7 @@ static bool llama_model_load(
int n_ctx,
int n_batch,
int n_gqa,
+ float rms_norm_eps,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
@@ -1343,7 +1351,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+ llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::exception & err) {
@@ -1396,10 +1404,12 @@ static bool llama_eval_internal(
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
+
LLAMA_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
+ const float rms_norm_eps = hparams.f_rms_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
@@ -1479,7 +1489,7 @@ static bool llama_eval_internal(
// norm
{
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_0");
@@ -1627,7 +1637,7 @@ static bool llama_eval_internal(
{
// norm
{
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_1");
@@ -1680,7 +1690,7 @@ static bool llama_eval_internal(
// norm
{
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2");
@@ -3084,7 +3094,7 @@ struct llama_model * llama_load_model_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
+ if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
params.progress_callback_user_data)) {
diff --git a/llama.h b/llama.h
index 81a30e1..843b0bf 100644
--- a/llama.h
+++ b/llama.h
@@ -87,6 +87,7 @@ extern "C" {
int32_t n_ctx; // text context
int32_t n_batch; // prompt processing batch size
int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
+ float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
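
At the public API level, the new field can be set directly on llama_context_params before loading a model. A hypothetical usage sketch (the model path is a placeholder; the default comes from llama_context_default_params(), which this commit sets to 1e-6f):

    struct llama_context_params cparams = llama_context_default_params();
    cparams.rms_norm_eps = 1e-5f;  // e.g. for LLaMA v2; keep the 1e-6f default otherwise

    struct llama_model * model = llama_load_model_from_file("path/to/model.ggml", cparams);
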
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
index ef20bce..6d31221 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.c
@@ -850,7 +850,7 @@ int main(int argc, const char ** argv) {
ggml_set_param(ctx0, x[i]);
}
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
}