aboutsummaryrefslogtreecommitdiff
path: root/examples/baby-llama
diff options
context:
space:
mode:
Diffstat (limited to 'examples/baby-llama')
-rw-r--r--examples/baby-llama/baby-llama.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 5573c15..e5639da 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -79,34 +79,39 @@ struct ggml_tensor * randomize_tensor_normal(
int ndims,
const int64_t ne[],
struct random_normal_distribution * rnd) {
+ float scale = 1.0; // xavier
switch (ndims) {
case 1:
+ scale /= sqrtf(ne[0]);
for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i0] = frand_normal(rnd);
+ ((float *)tensor->data)[i0] = scale * frand_normal(rnd);
}
break;
case 2:
+ scale /= sqrtf(ne[0]+ne[1]);
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd);
+ ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd);
}
}
break;
case 3:
+ scale /= sqrtf(ne[0]+ne[1]);
for (int i2 = 0; i2 < ne[2]; i2++) {
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+ ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
}
}
}
break;
case 4:
+ scale /= sqrtf(ne[0]+ne[1]);
for (int i3 = 0; i3 < ne[3]; i3++) {
for (int i2 = 0; i2 < ne[2]; i2++) {
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
- ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd);
+ ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd);
}
}
}