aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test-grad0.c60
1 files changed, 59 insertions, 1 deletions
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
index ec50592..c8c2c0f 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.c
@@ -5,7 +5,7 @@
#include <stdlib.h>
#include <assert.h>
-#define MAX_NARGS 2
+#define MAX_NARGS 3
#undef MIN
#undef MAX
@@ -1090,6 +1090,25 @@ int main(int argc, const char ** argv) {
}
}
+ // cross_entropy_loss
+ {
+ const int nargs = 1;
+
+ int64_t ne2[4];
+ get_random_dims(ne2, 4);
+
+ for (int ndims = 1; ndims <= 3; ++ndims) {
+ x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+ x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f);
+ ggml_set_param(ctx0, x[0]);
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
+
+ check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
+ // finite differences regularly fails!
+ }
+ }
+
// rope
{
const int nargs = 1;
@@ -1124,6 +1143,45 @@ int main(int argc, const char ** argv) {
}
}
+ // flash_attn
+ {
+ const int nargs = 3;
+
+ int64_t ne2[4];
+
+ get_random_dims(ne2, 4);
+ int64_t D = ne2[0];
+ int64_t N = ne2[1];
+ int64_t M = ne2[2] + N;
+ int64_t B = ne2[3];
+
+ for (int masked = 0; masked <= 1; ++masked) {
+ for (int ndims = 2; ndims <= 4; ++ndims) {
+ int64_t neq[4] = { D, N, B, ne[3] };
+ int64_t nek[4] = { D, M, B, ne[3] };
+ int64_t nev[4] = { M, D, B, ne[3] };
+ if (ndims == 2) {
+ neq[2] = 1; neq[3] = 1;
+ nek[2] = 1; nek[3] = 1;
+ nev[2] = 1; nev[3] = 1;
+ } else if (ndims == 3) {
+ neq[3] = 1;
+ nek[3] = 1;
+ nev[3] = 1;
+ }
+ x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
+ x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
+ x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
+ ggml_set_param(ctx0, x[0]);
+ ggml_set_param(ctx0, x[1]);
+ ggml_set_param(ctx0, x[2]);
+
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+ check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+ }
+ }
+ }
ggml_free(ctx0);
}