aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-06-24 19:40:18 +0300
committerGeorgi Gerganov <ggerganov@gmail.com>2023-06-24 19:40:18 +0300
commit65bdd52a867539691007f85c5508146d507f72c1 (patch)
tree4249b54aeb868112463d245632b3987c59785009 /tests
parentfdd18609113862dc6eb34dfc44a093d54c59ff1f (diff)
tests : sync test-grad0 from ggml
Diffstat (limited to 'tests')
-rw-r--r--tests/test-grad0.c20
1 files changed, 20 insertions, 0 deletions
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
index c8c2c0f..b5a499c 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.c
@@ -1,3 +1,4 @@
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
#include "ggml.h"
#include <math.h>
@@ -5,6 +6,10 @@
#include <stdlib.h>
#include <assert.h>
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
#define MAX_NARGS 3
#undef MIN
@@ -197,8 +202,23 @@ bool check_gradient(
float max_error_abs,
float max_error_rel) {
+ static int n_threads = -1;
+ if (n_threads < 0) {
+ n_threads = GGML_DEFAULT_N_THREADS;
+
+ const char *env = getenv("GGML_N_THREADS");
+ if (env) {
+ n_threads = atoi(env);
+ }
+
+ printf("GGML_N_THREADS = %d\n", n_threads);
+ }
+
struct ggml_cgraph gf = ggml_build_forward (f);
+ gf.n_threads = n_threads;
+
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+ gb.n_threads = n_threads;
ggml_graph_compute(ctx0, &gf);
ggml_graph_reset (&gf);