aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/CMakeLists.txt7
-rw-r--r--tests/test-double-float.cpp (renamed from tests/test-double-float.c)12
-rw-r--r--tests/test-grad0.cpp (renamed from tests/test-grad0.c)507
-rw-r--r--tests/test-opt.cpp (renamed from tests/test-opt.c)21
-rw-r--r--tests/test-sampling.cpp2
5 files changed, 432 insertions, 117 deletions
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 1acf050..1a40edb 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,14 +1,15 @@
function(llama_add_test source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source})
+ install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE llama)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
endfunction()
-# llama_add_test(test-double-float.c) # SLOW
+# llama_add_test(test-double-float.cpp) # SLOW
llama_add_test(test-quantize-fns.cpp)
llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
-llama_add_test(test-grad0.c) # SLOW
-# llama_add_test(test-opt.c) # SLOW
+llama_add_test(test-grad0.cpp) # SLOW
+# llama_add_test(test-opt.cpp) # SLOW
diff --git a/tests/test-double-float.c b/tests/test-double-float.cpp
index 89dafc9..b506f27 100644
--- a/tests/test-double-float.c
+++ b/tests/test-double-float.cpp
@@ -3,10 +3,11 @@
// This is done by checking all finite (non-NaN, non-infinite) floats.
#undef NDEBUG
-#include <assert.h>
+#include <cassert>
#include <immintrin.h>
-#include <math.h>
-#include <stdint.h>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdouble-promotion"
@@ -32,8 +33,9 @@ inline static float silu_float(float x) {
int main(void) {
uint32_t x = UINT32_MAX;
do {
- float f = *(float *)&x;
- assert(!isfinite(f) || (round_orig(f) == round_float(f)));
+ float f;
+ memcpy(&f, &x, sizeof(x));
+ assert(!std::isfinite(f) || (round_orig(f) == round_float(f)));
} while (x--);
#ifdef __F16C__
diff --git a/tests/test-grad0.c b/tests/test-grad0.cpp
index 01467bc..75a698d 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.cpp
@@ -1,10 +1,10 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#include "ggml.h"
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <assert.h>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cassert>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -47,16 +47,16 @@
#define GGML_PRINT(...) printf(__VA_ARGS__)
-float frand(void) {
+static float frand(void) {
return (float)rand()/(float)RAND_MAX;
}
-int irand(int n) {
+static int irand(int n) {
if (n == 0) return 0;
return rand()%n;
}
-void get_random_dims(int64_t * dims, int ndims) {
+static void get_random_dims(int64_t * dims, int ndims) {
dims[0] = dims[1] = dims[2] = dims[3] = 1;
for (int i = 0; i < ndims; i++) {
@@ -64,7 +64,7 @@ void get_random_dims(int64_t * dims, int ndims) {
}
}
-struct ggml_tensor * get_random_tensor(
+static struct ggml_tensor * get_random_tensor_f32(
struct ggml_context * ctx0,
int ndims,
int64_t ne[],
@@ -112,7 +112,55 @@ struct ggml_tensor * get_random_tensor(
return result;
}
-struct ggml_tensor * get_random_tensor_int(
+static struct ggml_tensor * get_random_tensor_f16( // creates a GGML_TYPE_F16 tensor and fills it with random values
+ struct ggml_context * ctx0, // context handed to ggml_new_tensor for the new tensor
+ int ndims, // number of dimensions to create and fill; only 1..4 are handled below
+ int64_t ne[], // extent of each dimension; ne[0] is the fastest-varying index in the fill loops
+ float fmin, // lower end of the value range before conversion to fp16
+ float fmax) { // upper end of the value range before conversion to fp16
+ struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
+ // every element is frand()*(fmax - fmin) + fmin, computed in float and stored via ggml_fp32_to_fp16
+ switch (ndims) { // one explicit loop nest per supported rank
+ case 1:
+ for (int i0 = 0; i0 < ne[0]; i0++) {
+ ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+ }
+ break;
+ case 2:
+ for (int i1 = 0; i1 < ne[1]; i1++) {
+ for (int i0 = 0; i0 < ne[0]; i0++) {
+ ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); // flat index: row i1, column i0
+ }
+ }
+ break;
+ case 3:
+ for (int i2 = 0; i2 < ne[2]; i2++) {
+ for (int i1 = 0; i1 < ne[1]; i1++) {
+ for (int i0 = 0; i0 < ne[0]; i0++) {
+ ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); // flat index over (i2, i1, i0)
+ }
+ }
+ }
+ break;
+ case 4:
+ for (int i3 = 0; i3 < ne[3]; i3++) {
+ for (int i2 = 0; i2 < ne[2]; i2++) {
+ for (int i1 = 0; i1 < ne[1]; i1++) {
+ for (int i0 = 0; i0 < ne[0]; i0++) {
+ ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); // flat index over (i3, i2, i1, i0)
+ }
+ }
+ }
+ }
+ break;
+ default:
+ assert(false); // any ndims outside 1..4 is unsupported and fails this assertion
+ };
+
+ return result; // the newly created tensor; elements were written through result->data
+}
+
+static struct ggml_tensor * get_random_tensor_i32(
struct ggml_context * ctx0,
int ndims,
int64_t ne[],
@@ -160,24 +208,7 @@ struct ggml_tensor * get_random_tensor_int(
return result;
}
-float get_element(const struct ggml_tensor * t, int idx) {
- if (t->type == GGML_TYPE_F32) {
- return ((float *)t->data)[idx];
- }
-
- if (t->type == GGML_TYPE_I32) {
- return ((int32_t *)t->data)[idx];
- }
-
- assert(false);
- return INFINITY;
-}
-
-void set_element(struct ggml_tensor * t, int idx, float value) {
- ((float *)t->data)[idx] = value;
-}
-
-void print_elements(const char* label, const struct ggml_tensor * t) {
+static void print_elements(const char* label, const struct ggml_tensor * t) {
if (!t) {
printf("%s: %s = null\n", __func__, label);
return;
@@ -186,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
printf("%s: %s = [", __func__, label);
for (int k = 0; k < nelements; ++k) {
if (k > 0) { printf(", "); }
- printf("%.5f", get_element(t, k));
+ printf("%.5f", ggml_get_f32_1d(t, k));
}
printf("] shape: [");
for (int k = 0; k < t->n_dims; ++k) {
@@ -197,7 +228,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
}
-bool check_gradient(
+static bool check_gradient(
const char * op_name,
struct ggml_context * ctx0,
struct ggml_tensor * x[],
@@ -237,23 +268,23 @@ bool check_gradient(
const int nelements = ggml_nelements(x[i]);
for (int k = 0; k < nelements; ++k) {
// compute gradient using finite differences
- const float x0 = get_element(x[i], k);
+ const float x0 = ggml_get_f32_1d(x[i], k);
const float xm = x0 - eps;
const float xp = x0 + eps;
- set_element(x[i], k, xp);
+ ggml_set_f32_1d(x[i], k, xp);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0);
- set_element(x[i], k, xm);
+ ggml_set_f32_1d(x[i], k, xm);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps);
- set_element(x[i], k, x0);
+ ggml_set_f32_1d(x[i], k, x0);
// compute gradient using backward graph
ggml_graph_reset (&gf);
@@ -261,7 +292,7 @@ bool check_gradient(
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
- const float g1 = get_element(x[i]->grad, k);
+ const float g1 = ggml_get_f32_1d(x[i]->grad, k);
const float error_abs = fabsf(g0 - g1);
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
@@ -279,7 +310,7 @@ bool check_gradient(
}
// TODO: clean-up this ..
-bool check_mat_mul(
+static bool check_mat_mul(
const struct ggml_tensor * y,
const struct ggml_tensor * x0,
const struct ggml_tensor * x1) {
@@ -342,9 +373,9 @@ bool check_mat_mul(
int main(int argc, const char ** argv) {
struct ggml_init_params params = {
- .mem_size = 128*1024*1024,
- .mem_buffer = NULL,
- .no_alloc = false,
+ /* .mem_size = */ 128*1024*1024,
+ /* .mem_buffer = */ NULL,
+ /* .no_alloc = */ false,
};
int64_t ne[4];
@@ -392,19 +423,35 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * x[MAX_NARGS];
- // add
+ // add f32
{
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
- check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+ check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+ }
+ }
+
+ // add f16
+ {
+ const int nargs = 2;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+ check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
}
}
@@ -414,7 +461,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -430,7 +477,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -446,7 +493,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -462,7 +509,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -478,7 +525,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -494,7 +541,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -510,7 +557,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -527,7 +574,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -537,6 +584,40 @@ int main(int argc, const char ** argv) {
}
}
+ // mean, not yet fully implemented
+ if(0)
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
+
+ check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
+ // argmax
+ if (0)
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
+
+ check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
// repeat
{
int64_t ne2[4];
@@ -549,15 +630,36 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
- x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
}
+ }
+
+ // repeat back
+ {
+ int64_t ne2[4];
+ get_random_dims(ne2, 4);
+
+ ne2[0] = ne[0] * ne2[0];
+ ne2[1] = ne[1] * ne2[1];
+ ne2[2] = 1;
+ ne2[3] = 1;
+
+ const int nargs = 1;
+ for (int ndims = 1; ndims <= 2; ++ndims) {
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[0]);
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
+ check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+ }
}
// abs (finite differences do not work)
@@ -566,7 +668,7 @@ int main(int argc, const char ** argv) {
// for (int ndims = 1; ndims <= 2; ++ndims) {
// for (int i = 0; i < nargs; ++i) {
- // x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ // x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
// ggml_set_param(ctx0, x[i]);
// }
@@ -576,17 +678,82 @@ int main(int argc, const char ** argv) {
// }
//}
+ // sgn
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
+
+ check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
+ // neg
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
+
+ check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
+ // step
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
+
+ check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
+ // tanh, not yet fully implemented
+ if(0)
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
+
+ check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
// mul_mat
{
const int nargs = 2;
for (int ndims = 2; ndims <= 2; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
{
int64_t ne2[4];
get_random_dims(ne2, 4);
ne2[0] = ne[0];
- x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
}
ggml_set_param(ctx0, x[0]);
@@ -602,13 +769,63 @@ int main(int argc, const char ** argv) {
}
}
+ // elu, not yet fully implemented
+ if(0)
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
+
+ check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
+ // relu
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
+
+ check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+ }
+ }
+
+ // gelu, not yet fully implemented
+ if(0)
+ {
+ const int nargs = 1;
+
+ for (int ndims = 1; ndims <= 4; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+
+ struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
+
+ check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+ }
+ }
+
// silu
{
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
@@ -629,11 +846,11 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
}
@@ -647,8 +864,8 @@ int main(int argc, const char ** argv) {
ne2[0] = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
- x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]);
@@ -659,20 +876,37 @@ int main(int argc, const char ** argv) {
}
}
- // cpy
+ // cpy f32
{
const int nargs = 2;
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
- x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
- check_gradient("cpy", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+ check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+ }
+ }
+
+ // cpy f16
+ {
+ const int nargs = 2;
+
+ for (int ndims = 1; ndims <= 2; ++ndims) {
+ for (int i = 0; i < nargs; ++i) {
+ x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+ ggml_set_param(ctx0, x[i]);
+ }
+ // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
+
+ check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
}
}
@@ -689,8 +923,8 @@ int main(int argc, const char ** argv) {
for (int i = 0; i < ndims; ++i) {
ne2[0] *= ne[i];
}
- x[0] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
- x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -712,8 +946,8 @@ int main(int argc, const char ** argv) {
for (int i = 0; i < ndims; ++i) {
ne2[0] *= ne[i];
}
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
- x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -729,7 +963,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 1);
@@ -737,7 +971,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 1);
}
- x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@@ -758,7 +992,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2;
for (int ndims = 2; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 2);
@@ -766,7 +1000,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 2);
}
- x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -790,7 +1024,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2;
for (int ndims = 3; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 3);
@@ -798,7 +1032,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 3);
}
- x[1] = get_random_tensor(ctx0, 3, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -824,7 +1058,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2;
for (int ndims = 4; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 4);
@@ -832,7 +1066,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4);
}
- x[1] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -858,7 +1092,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 1);
@@ -866,7 +1100,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 1);
}
- x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@@ -887,7 +1121,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
for (int ndims = 2; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 2);
@@ -895,7 +1129,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 2);
}
- x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -915,7 +1149,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -941,7 +1175,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
get_random_dims(ne2, 2);
while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
@@ -971,7 +1205,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
get_random_dims(ne2, 3);
while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
@@ -1010,7 +1244,7 @@ int main(int argc, const char ** argv) {
for (int i=ndims; i<4; ++i) {
ne2[i] = 1;
}
- x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -1043,7 +1277,7 @@ int main(int argc, const char ** argv) {
for (int i=ndims; i<4; ++i) {
ne2[i] = 1;
}
- x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -1060,8 +1294,8 @@ int main(int argc, const char ** argv) {
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
const int nargs = 1;
const int ndims = 2;
- x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
- x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
ggml_set_param(ctx0, x[0]);
@@ -1075,7 +1309,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
const int ndims = 2;
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
int n_past = irand(ne[0]);
@@ -1090,7 +1324,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1;
const int ndims = 2;
- x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
int n_past = irand(ne[0]);
@@ -1108,7 +1342,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4);
for (int ndims = 1; ndims <= 3; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
@@ -1125,8 +1359,8 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4);
for (int ndims = 1; ndims <= 3; ++ndims) {
- x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
- x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
@@ -1136,7 +1370,7 @@ int main(int argc, const char ** argv) {
}
}
- // rope
+ // rope f32
{
const int nargs = 1;
@@ -1148,7 +1382,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 3; ndims <= 4; ++ndims) {
for (int mode = 0; mode < 4; ++mode) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
- x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
@@ -1163,14 +1397,89 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
- GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
- check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+ GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+ check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+ }
+ }
+ }
+ }
+
+ // rope f16
+ {
+ const int nargs = 1;
+
+ int64_t ne2[4];
+ get_random_dims(ne2, 4);
+ ne2[0] += ne2[0] % 2;
+ int n_rot = ne2[0];
+
+ for (int ndims = 3; ndims <= 4; ++ndims) {
+ for (int mode = 0; mode < 4; ++mode) {
+ for (int n_past = 1; n_past < ne2[2]; ++n_past) {
+ x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
+
+ ggml_set_param(ctx0, x[0]);
+
+ const bool skip_past = (mode & 1);
+ if (skip_past) {
+ // we have no past, so this would have to work on uninitialized memory.
+ // we only test the gradients here;
+ // skip_past should have no influence on gradient computation.
+ // so when other modes work, we assume that this does as well.
+ continue;
+ }
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+
+ GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+ check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+ }
+ }
+ }
+ }
+
+ // flash_attn f32
+ {
+ const int nargs = 3;
+
+ int64_t ne2[4];
+
+ get_random_dims(ne2, 4);
+ int64_t D = ne2[0];
+ int64_t N = ne2[1];
+ int64_t M = ne2[2] + N;
+ int64_t B = ne2[3];
+
+ for (int masked = 0; masked <= 1; ++masked) {
+ for (int ndims = 2; ndims <= 4; ++ndims) {
+ int64_t neq[4] = { D, N, B, ne[3] };
+ int64_t nek[4] = { D, M, B, ne[3] };
+ int64_t nev[4] = { M, D, B, ne[3] };
+ if (ndims == 2) {
+ neq[2] = 1; neq[3] = 1;
+ nek[2] = 1; nek[3] = 1;
+ nev[2] = 1; nev[3] = 1;
+ } else if (ndims == 3) {
+ neq[3] = 1;
+ nek[3] = 1;
+ nev[3] = 1;
}
+ x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+ x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+ x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+ ggml_set_param(ctx0, x[0]);
+ ggml_set_param(ctx0, x[1]);
+ ggml_set_param(ctx0, x[2]);
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+ check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
}
}
}
- // flash_attn
+ // flash_attn f16, not yet fully implemented
+ if(0)
{
const int nargs = 3;
@@ -1196,16 +1505,16 @@ int main(int argc, const char ** argv) {
nek[3] = 1;
nev[3] = 1;
}
- x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
- x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
- x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
+ x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
+ x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
+ x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]);
ggml_set_param(ctx0, x[2]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
- check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+ check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
}
}
}
diff --git a/tests/test-opt.c b/tests/test-opt.cpp
index 5531814..8ab2402 100644
--- a/tests/test-opt.c
+++ b/tests/test-opt.cpp
@@ -1,9 +1,9 @@
#include "ggml.h"
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <assert.h>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cassert>
#define MAX_NARGS 2
@@ -119,15 +119,16 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
int main(void) {
struct ggml_init_params params = {
- .mem_size = 1024*1024*1024,
- .mem_buffer = NULL,
- .no_alloc = false,
+ /* .mem_size = */ 1024*1024*1024,
+ /* .mem_buffer = */ NULL,
+ /* .no_alloc = */ false,
};
+
struct ggml_context * ctx = ggml_init(params);
- int64_t ne1[4] = {4, 1024, 1, 1};
- int64_t ne2[4] = {4, 2048, 1, 1};;
- int64_t ne3[4] = {1024, 2048, 1, 1};
+ int64_t ne1[4] = {4, 128, 1, 1};
+ int64_t ne2[4] = {4, 256, 1, 1};;
+ int64_t ne3[4] = {128, 256, 1, 1};
struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 64f9455..4437c39 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -200,4 +200,6 @@ int main(void) {
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
printf("OK\n");
+
+ return 0;
}