diff options
Diffstat (limited to 'tests/test-quantize-fns.cpp')
-rw-r--r-- | tests/test-quantize-fns.cpp | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 7e091e8..a31a188 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -36,7 +36,7 @@ float array_rmse(const float * a1, const float * a2, size_t n) { // Total quantization error on test data float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { - std::vector<uint8_t> tmp_q(test_size); + std::vector<uint8_t> tmp_q(2*test_size); std::vector<float> tmp_out(test_size); qfns.quantize_row_q(test_data, tmp_q.data(), test_size); @@ -46,7 +46,7 @@ float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const fl // Total quantization error on test data float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { - std::vector<uint8_t> tmp_q(test_size); + std::vector<uint8_t> tmp_q(2*test_size); std::vector<float> tmp_out(test_size); std::vector<float> tmp_out_ref(test_size); @@ -69,10 +69,10 @@ float dot_product(const float * a1, const float * a2, size_t test_size) { // Total dot product error float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) { - std::vector<uint8_t> tmp_q1(test_size); - std::vector<uint8_t> tmp_q2(test_size*2); + std::vector<uint8_t> tmp_q1(2*test_size); + std::vector<uint8_t> tmp_q2(2*test_size); - qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size); + qfns.quantize_row_q (test_data1, tmp_q1.data(), test_size); qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size); float result = INFINITY; @@ -125,7 +125,7 @@ int main(int argc, char * argv[]) { failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR); num_failed += failed; if (failed || verbose) { - printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); + printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); } const float reference_error = reference_quantization_error(qfns, test_size, test_data.data()); @@ -139,7 +139,7 @@ int main(int argc, char * argv[]) { failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR); num_failed += failed; if (failed || verbose) { - printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); + printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); } } } |