diff --git a/tests/test-grad0.c b/tests/test-grad0.c index c8c2c0f..b5a499c 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -1,3 +1,4 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows #include "ggml.h" #include @@ -5,6 +6,10 @@ #include #include +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + #define MAX_NARGS 3 #undef MIN @@ -197,8 +202,23 @@ bool check_gradient( float max_error_abs, float max_error_rel) { + static int n_threads = -1; + if (n_threads < 0) { + n_threads = GGML_DEFAULT_N_THREADS; + + const char *env = getenv("GGML_N_THREADS"); + if (env) { + n_threads = atoi(env); + } + + printf("GGML_N_THREADS = %d\n", n_threads); + } + struct ggml_cgraph gf = ggml_build_forward (f); + gf.n_threads = n_threads; + struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false); + gb.n_threads = n_threads; ggml_graph_compute(ctx0, &gf); ggml_graph_reset (&gf);