From 41c674161fb2459bdf7806d1eebead15bc5d046e Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 24 Jul 2023 17:57:12 +0200 Subject: make rms_norm_eps a parameter (#2374) * make rms_norm_eps a parameter * add rms_norm_eps to command line * fix baby llama, test-grad0 * use scientific notation for eps param in the help ggml-ci --- ggml-cuda.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'ggml-cuda.cu') diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b8c9835..87a1660 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-6f; - float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += WARP_SIZE) { @@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i norm_f32<<>>(x, dst, ncols); } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols); + rms_norm_f32<<>>(x, dst, ncols, eps); } static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { @@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + // compute - rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main); (void) src1; (void) dst; -- cgit v1.2.3