From b1f429095328a34556c0e9a7a2fefced3db3368c Mon Sep 17 00:00:00 2001 From: wzy <32936898+Freed-Wu@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:01:11 +0800 Subject: cmake : install targets (#2256) fix #2252 --- examples/train-text-from-scratch/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) (limited to 'examples/train-text-from-scratch') diff --git a/examples/train-text-from-scratch/CMakeLists.txt b/examples/train-text-from-scratch/CMakeLists.txt index 1a44c49..4459516 100644 --- a/examples/train-text-from-scratch/CMakeLists.txt +++ b/examples/train-text-from-scratch/CMakeLists.txt @@ -1,4 +1,5 @@ set(TARGET train-text-from-scratch) add_executable(${TARGET} train-text-from-scratch.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) -- cgit v1.2.3 From 513f8619535a64fa9ace808cdcbcf66211535f5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 21 Jul 2023 14:51:34 +0300 Subject: ggml : fix rope args order + assert (#2054) --- .../train-text-from-scratch.cpp | 6 +++--- ggml.c | 24 +++++++++++++--------- ggml.h | 7 ++++--- llama.cpp | 4 ++-- 4 files changed, 23 insertions(+), 18 deletions(-) (limited to 'examples/train-text-from-scratch') diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index afbb4a7..449b4e9 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( gf->perf_time_us = 0; const auto & hparams = model->hparams; - //const int n_ctx = hparams.n_ctx; + const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; @@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head); t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd); t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch); - t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); + t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch); t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch); - t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); + t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch); t04->grad = expand(gb, ggml_add_inplace(ctx0, ggml_add_inplace(ctx0, diff --git a/ggml.c b/ggml.c index c56a3d0..7ecabc5 100644 --- a/ggml.c +++ b/ggml.c @@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl( int n_past, int n_dims, int mode, + int n_ctx, float freq_base, float freq_scale, - int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false); } struct ggml_tensor * ggml_rope_inplace( @@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true); } struct ggml_tensor * ggml_rope_custom_inplace( @@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace( int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true); + float freq_scale) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true); } // ggml_rope_back @@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { + int mode, + int n_ctx) { GGML_ASSERT(n_past >= 0); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back( ggml_scratch_save(ctx); - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); ggml_set_name(b, "n_past, n_dims, mode"); ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_dims; ((int32_t *) b->data)[2] = mode; + ((int32_t *) b->data)[3] = n_ctx; ggml_scratch_load(ctx); @@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; + const int n_ctx = ((int32_t *) src1->data)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, n_past, n_dims, - mode), + mode, + n_ctx), inplace); } if (src1->grad) { @@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + assert(ggml_nelements(src1) == 4); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; diff --git a/ggml.h b/ggml.h index 24856a2..5023b16 100644 --- a/ggml.h +++ b/ggml.h @@ -1128,9 +1128,9 @@ extern "C" { int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx); + float freq_scale); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1139,7 +1139,8 @@ extern "C" { struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // alibi position embedding // in-place, returns view(a) diff --git a/llama.cpp b/llama.cpp index 3b0024e..0a381af 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1452,11 +1452,11 @@ static bool llama_eval_internal( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); -- cgit v1.2.3 From 41c674161fb2459bdf7806d1eebead15bc5d046e Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 24 Jul 2023 17:57:12 +0200 Subject: make rms_norm_eps a parameter (#2374) * make rms_norm_eps a parameter * add rms_norm_eps to command line * fix baby llama, test-grad0 * use scientific notation for eps param in the help ggml-ci --- examples/baby-llama/baby-llama.cpp | 20 ++++++++------ examples/common.cpp | 8 ++++++ examples/common.h | 23 ++++++++-------- .../train-text-from-scratch.cpp | 32 ++++++++++++---------- ggml-cuda.cu | 13 +++++---- ggml-metal.m | 3 +- ggml.c | 16 +++++++---- ggml.h | 7 +++-- llama.cpp | 20 ++++++++++---- llama.h | 1 + tests/test-grad0.c | 2 +- 11 files changed, 89 insertions(+), 56 deletions(-) (limited to 'examples/train-text-from-scratch') diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 4965881..f9dc0aa 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -8,6 +8,8 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static const float rms_norm_eps = 1e-6f; + float frand() { return (float)rand()/(float)RAND_MAX; } @@ -562,7 +564,7 @@ struct ggml_tensor * forward( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // cur = attention_norm*cur cur = ggml_mul(ctx0, @@ -685,7 +687,7 @@ struct ggml_tensor * forward( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); // cur = ffn_norm*cur // cur shape [n_embd,N,1,1] @@ -729,7 +731,7 @@ struct ggml_tensor * forward( { // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // inpL = norm*inpL // inpL shape [n_embd,N,1,1] @@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = attention_norm*cur @@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = ffn_norm*cur @@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch( { // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(inpL, n_embd, N*n_batch); // inpL = norm*inpL @@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // cur = attention_norm*cur cur = ggml_mul(ctx0, @@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); // cur = ffn_norm*cur // cur shape [n_embd,N,1,1] @@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora( { // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // inpL = norm*inpL // inpL shape [n_embd,N,1,1] diff --git a/examples/common.cpp b/examples/common.cpp index 779605f..0e88a12 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_gqa = std::stoi(argv[i]); + } else if (arg == "-eps" || arg == "--rms-norm-eps") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rms_norm_eps = std::stof(argv[i]); } else if (arg == "--rope-freq-base") { if (++i >= argc) { invalid_param = true; @@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); + fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); @@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.n_ctx = params.n_ctx; lparams.n_batch = params.n_batch; lparams.n_gqa = params.n_gqa; + lparams.rms_norm_eps = params.rms_norm_eps; lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; lparams.tensor_split = params.tensor_split; diff --git a/examples/common.h b/examples/common.h index 7086606..894a085 100644 --- a/examples/common.h +++ b/examples/common.h @@ -22,18 +22,19 @@ int32_t get_num_physical_cores(); struct gpt_params { - uint32_t seed = -1; // RNG seed + uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_gpu_layers = 0; // number of layers to store in VRAM - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_gpu_layers = 0; // number of layers to store in VRAM + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + float rms_norm_eps = 1e-6; // rms norm epsilon float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_scale = 1.0f; // RoPE frequency scaling factor diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 449b4e9..4bbf6b7 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -16,6 +16,8 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static const float rms_norm_eps = 1e-6f; + struct random_normal_distribution { std::mt19937 gen; std::normal_distribution rd; @@ -439,7 +441,7 @@ struct ggml_tensor * forward( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // cur = attention_norm*cur cur = ggml_mul(ctx0, @@ -562,7 +564,7 @@ struct ggml_tensor * forward( // norm { // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); // cur = ffn_norm*cur // cur shape [n_embd,N,1,1] @@ -606,7 +608,7 @@ struct ggml_tensor * forward( { // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); // inpL = norm*inpL // inpL shape [n_embd,N,1,1] @@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = attention_norm*cur @@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = ffn_norm*cur @@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch( { // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(inpL, n_embd, N*n_batch); // inpL = norm*inpL @@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = attention_norm*cur @@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache( // norm { // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = ffn_norm*cur @@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache( { // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(inpL, n_embd, N*n_batch); // inpL = norm*inpL @@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = attention_norm*cur @@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( { // norm { - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); assert_shape_2d(cur, n_embd, N*n_batch); // cur = ffn_norm*cur @@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( // norm { - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); assert_shape_2d(inpL, n_embd, N*n_batch); // inpL = norm*inpL @@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( struct my_llama_layer & layer = model->layers[il]; // tensors with values necessary for backward pass are in persistent buf(-1) // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed. - use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch); @@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); @@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( } clr_buf(0); use_buf(0); - struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch); + struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch); struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch); use_buf(-1); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b8c9835..87a1660 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-6f; - float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += WARP_SIZE) { @@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i norm_f32<<>>(x, dst, ncols); } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols); + rms_norm_f32<<>>(x, dst, ncols, eps); } static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { @@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + // compute - rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main); (void) src1; (void) dst; diff --git a/ggml-metal.m b/ggml-metal.m index 1fd6e85..c1db3d1 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -812,7 +812,8 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const float eps = 1e-6f; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); const int nth = 512; diff --git a/ggml.c b/ggml.c index 960b805..11226c8 100644 --- a/ggml.c +++ b/ggml.c @@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace( static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, + float eps, bool inplace) { bool is_node = false; @@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - // TODO: maybe store epsilon here? + ggml_set_op_params(result, &eps, sizeof(eps)); result->op = GGML_OP_RMS_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, false); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_rms_norm_impl(ctx, a, true); + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); } struct ggml_tensor * ggml_rms_norm_back( @@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32( GGML_TENSOR_UNARY_OP_LOCALS; - const float eps = 1e-6f; // TODO: make this a parameter + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { diff --git a/ggml.h b/ggml.h index de44fba..1870b62 100644 --- a/ggml.h +++ b/ggml.h @@ -866,14 +866,17 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); // a - x // b - dy + // TODO: update with configurable eps GGML_API struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/llama.cpp b/llama.cpp index 0288f7e..b42b410 100644 --- a/llama.cpp +++ b/llama.cpp @@ -186,6 +186,7 @@ struct llama_hparams { // LLaMAv2 // TODO: load from model data hparams float f_ffn_mult = 1.0f; + float f_rms_norm_eps = 1e-6f; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; @@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() { /*.n_ctx =*/ 512, /*.n_batch =*/ 512, /*.n_gqa =*/ 1, + /*.rms_norm_eps =*/ 1e-6f, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -1000,6 +1002,7 @@ static void llama_model_load_internal( int n_ctx, int n_batch, int n_gqa, + float rms_norm_eps, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1024,6 +1027,9 @@ static void llama_model_load_internal( auto & hparams = model.hparams; + // TODO: read from file + hparams.f_rms_norm_eps = rms_norm_eps; + { switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; @@ -1072,6 +1078,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); @@ -1330,6 +1337,7 @@ static bool llama_model_load( int n_ctx, int n_batch, int n_gqa, + float rms_norm_eps, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1343,7 +1351,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1396,10 +1404,12 @@ static bool llama_eval_internal( const int64_t n_vocab = hparams.n_vocab; const int64_t n_embd_gqa = hparams.n_embd_gqa(); + LLAMA_ASSERT(n_embd_head == hparams.n_rot); const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const float rms_norm_eps = hparams.f_rms_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -1479,7 +1489,7 @@ static bool llama_eval_internal( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_0"); @@ -1627,7 +1637,7 @@ static bool llama_eval_internal( { // norm { - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_1"); @@ -1680,7 +1690,7 @@ static bool llama_eval_internal( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); offload_func_nr(cur); ggml_set_name(cur, "rms_norm_2"); @@ -3084,7 +3094,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers, + if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { diff --git a/llama.h b/llama.h index 81a30e1..843b0bf 100644 --- a/llama.h +++ b/llama.h @@ -87,6 +87,7 @@ extern "C" { int32_t n_ctx; // text context int32_t n_batch; // prompt processing batch size int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) + float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams) int32_t n_gpu_layers; // number of layers to store in VRAM int32_t main_gpu; // the GPU that is used for scratch and small tensors diff --git a/tests/test-grad0.c b/tests/test-grad0.c index ef20bce..6d31221 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -850,7 +850,7 @@ int main(int argc, const char ** argv) { ggml_set_param(ctx0, x[i]); } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0])); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY); } -- cgit v1.2.3 From eb542d39324574a6778fad9ba9e34ba7a14a82a3 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:35:53 +0300 Subject: Add LLAMA_DEFAULT_RMS_EPS so we can change the default (#2384) Co-authored-by: Iwan Kawrakow --- examples/baby-llama/baby-llama.cpp | 6 +++++- examples/common.h | 2 +- examples/train-text-from-scratch/train-text-from-scratch.cpp | 2 +- llama.cpp | 4 ++-- llama.h | 4 ++++ 5 files changed, 13 insertions(+), 5 deletions(-) (limited to 'examples/train-text-from-scratch') diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index f9dc0aa..6fa55b3 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -8,7 +8,11 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const float rms_norm_eps = 1e-6f; +#ifdef LLAMA_DEFAULT_RMS_EPS +static const float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; +#else +static const float rms_norm_eps = 5e-6f; +#endif float frand() { return (float)rand()/(float)RAND_MAX; diff --git a/examples/common.h b/examples/common.h index 2d87c92..672dcf7 100644 --- a/examples/common.h +++ b/examples/common.h @@ -34,7 +34,7 @@ struct gpt_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - float rms_norm_eps = 1e-6; // rms norm epsilon + float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_scale = 1.0f; // RoPE frequency scaling factor diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 4bbf6b7..54dc2be 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -16,7 +16,7 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const float rms_norm_eps = 1e-6f; +static const float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; struct random_normal_distribution { std::mt19937 gen; diff --git a/llama.cpp b/llama.cpp index febefba..30d4b0a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -186,7 +186,7 @@ struct llama_hparams { // LLaMAv2 // TODO: load from model data hparams float f_ffn_mult = 1.0f; - float f_rms_norm_eps = 1e-6f; + float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; @@ -870,7 +870,7 @@ struct llama_context_params llama_context_default_params() { /*.n_ctx =*/ 512, /*.n_batch =*/ 512, /*.n_gqa =*/ 1, - /*.rms_norm_eps =*/ 1e-6f, + /*.rms_norm_eps =*/ LLAMA_DEFAULT_RMS_EPS, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, diff --git a/llama.h b/llama.h index 843b0bf..df46f9b 100644 --- a/llama.h +++ b/llama.h @@ -53,6 +53,10 @@ #define LLAMA_SUPPORTS_GPU_OFFLOAD #endif +#ifndef LLAMA_DEFAULT_RMS_EPS +#define LLAMA_DEFAULT_RMS_EPS 5e-6f +#endif + #ifdef __cplusplus extern "C" { #endif -- cgit v1.2.3