diff options
Diffstat (limited to 'examples/baby-llama')
-rw-r--r-- | examples/baby-llama/baby-llama.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 5573c15..e5639da 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -79,34 +79,39 @@ struct ggml_tensor * randomize_tensor_normal( int ndims, const int64_t ne[], struct random_normal_distribution * rnd) { + float scale = 1.0; // xavier switch (ndims) { case 1: + scale /= sqrtf(ne[0]); for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i0] = frand_normal(rnd); + ((float *)tensor->data)[i0] = scale * frand_normal(rnd); } break; case 2: + scale /= sqrtf(ne[0]+ne[1]); for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd); } } break; case 3: + scale /= sqrtf(ne[0]+ne[1]); for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } } break; case 4: + scale /= sqrtf(ne[0]+ne[1]); for (int i3 = 0; i3 < ne[3]; i3++) { for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } } |