diff options
author | unbounded <haakon@likedan.net> | 2023-04-22 11:10:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-22 12:10:39 +0300 |
commit | 5f939498d517b4dddbe904f202e895a3ecfb9dc4 (patch) | |
tree | 010b9312921f4234bda24d3ff10aba86e1d3b6ac /tests | |
parent | 36b4f7e06406eed8a605cc9f2921d9244ef6a8e5 (diff) |
ggml : unit test for quantization functions (#953)
* Unit test for quantization functions
Use the ggml_internal_get_quantize_fn function to loop through all
quantization formats and run a sanity check on the result.
Also add a microbenchmark that times these functions directly without
running the rest of the GGML graph.
* test-quantize-fns: CI fixes
Fix issues uncovered in CI
- need to use sizes divisible by 32*8 for loop unrolling
- use intrinsic header that should work on Mac
* test-quantize: remove
Per PR comment, subsumed by test-quantize-fns
* test-quantize: fix for q8_0 intermediates
Diffstat (limited to 'tests')
-rw-r--r-- | tests/CMakeLists.txt | 3 | ||||
-rw-r--r-- | tests/test-quantize-fns.cpp | 154 | ||||
-rw-r--r-- | tests/test-quantize-perf.cpp | 310 | ||||
-rw-r--r-- | tests/test-quantize.c | 42 |
4 files changed, 466 insertions, 43 deletions
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 157d733..81eadbc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,5 +6,6 @@ function(llama_add_test source) endfunction() # llama_add_test(test-double-float.c) # SLOW -llama_add_test(test-quantize.c) +llama_add_test(test-quantize-fns.cpp) +llama_add_test(test-quantize-perf.cpp) llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp new file mode 100644 index 0000000..5a54101 --- /dev/null +++ b/tests/test-quantize-fns.cpp @@ -0,0 +1,154 @@ +// Unit tests for quantization specific functions - quantize, dequantize and dot product + +#include "ggml.h" + +#undef NDEBUG +#include <assert.h> +#include <math.h> +#include <stdio.h> +#include <string> +#include <vector> + + +const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001; +const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002; +const float MAX_DOT_PRODUCT_ERROR = 0.02; + +const char* RESULT_STR[] = {"ok", "FAILED"}; + + +// Generate synthetic data +void generate_data(float offset, size_t n, float * dst) { + for (size_t i = 0; i < n; i++) { + dst[i] = 0.1 + 2*cosf(i + offset); + } +} + +// Calculate RMSE between two float arrays +float array_rmse(const float * a1, const float * a2, size_t n) { + double sum = 0; + for (size_t i = 0; i < n; i++) { + double diff = a1[i] - a2[i]; + sum += diff * diff; + } + return sqrtf(sum) / n; +} + +// Total quantization error on test data +float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { + std::vector<uint8_t> tmp_q(test_size); + std::vector<float> tmp_out(test_size); + + qfns.quantize_row_q(test_data, tmp_q.data(), test_size); + qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size); + return array_rmse(test_data, tmp_out.data(), test_size); +} + +// Total quantization error on test data +float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { + std::vector<uint8_t> tmp_q(test_size); + std::vector<float> tmp_out(test_size); + std::vector<float> tmp_out_ref(test_size); + + qfns.quantize_row_q(test_data, tmp_q.data(), test_size); + qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size); + + qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size); + qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size); + + return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size); +} + +float dot_product(const float * a1, const float * a2, size_t test_size) { + double sum = 0; + for (size_t i = 0; i < test_size; i++) { + sum += a1[i] * a2[i]; + } + return sum; +} + +// Total dot product error +float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) { + std::vector<uint8_t> tmp_q1(test_size); + std::vector<uint8_t> tmp_q2(test_size*2); + + qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size); + qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size); + + float result = INFINITY; + qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data()); + + const float dot_ref = dot_product(test_data1, test_data2, test_size); + + return fabsf(result - dot_ref) / test_size; +} + +int main(int argc, char * argv[]) { + bool verbose = false; + const size_t test_size = 32 * 128; + + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + + if (arg == "-v") { + verbose = true; + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + return 1; + } + } + + std::vector<float> test_data(test_size); + std::vector<float> test_data2(test_size); + + generate_data(0.0, test_data.size(), test_data.data()); + generate_data(1.0, test_data2.size(), test_data2.data()); + + // Initialize GGML, ensures float conversion tables are initialized + struct ggml_init_params ggml_params = { + /* .mem_size = */ 1*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true, + }; + struct ggml_context * ctx = ggml_init(ggml_params); + + int num_failed = 0; + bool failed = false; + + for (int i = 0; i < GGML_TYPE_COUNT; i++) { + ggml_type type = (ggml_type) i; + quantize_fns_t qfns = ggml_internal_get_quantize_fn(i); + + if (qfns.quantize_row_q) { + const float total_error = total_quantization_error(qfns, test_size, test_data.data()); + failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR); + num_failed += failed; + if (failed || verbose) { + printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); + } + + const float reference_error = reference_quantization_error(qfns, test_size, test_data.data()); + failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR); + num_failed += failed; + if (failed || verbose) { + printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error); + } + + const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data()); + failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR); + num_failed += failed; + if (failed || verbose) { + printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); + } + } + } + + if (num_failed || verbose) { + printf("%d tests failed\n", num_failed); + } + + ggml_free(ctx); + + return num_failed > 0; +} diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp new file mode 100644 index 0000000..883df05 --- /dev/null +++ b/tests/test-quantize-perf.cpp @@ -0,0 +1,310 @@ +// Benchmark quantization specific functions on synthetic data + +#include "ggml.h" + +#undef NDEBUG +#include <algorithm> +#include <assert.h> +#include <functional> +#include <inttypes.h> +#include <math.h> +#include <memory> +#include <stdio.h> +#include <string> +#include <vector> + +#define MAX_ALIGNMENT 64 +#define QK 32 +#define WARMUP 5 +#define ITERATIONS 10 + +#define L1_SIZE 32*128 +#define L2_SIZE 32*2048 +#define L3_SIZE 32*20480 +#define MEM_SIZE 32*2048000 + +struct quantize_perf_params { + std::vector<std::string> include_types; + std::vector<size_t> test_sizes; + size_t alignment_offset = 0; + bool op_quantize_row_q_reference = false; + bool op_quantize_row_q = false; + bool op_dequantize_row_q = false; + bool op_quantize_row_q_dot = false; + bool op_vec_dot_q = false; +}; + + +#if defined(__x86_64__) || defined(__i386__) + +#include <x86intrin.h> +inline int64_t cpu_cycles() { +// Rough way to detect new-ish CPUs +#ifdef __POPCNT__ + unsigned int dummy; + return __rdtscp(&dummy); +#else + return __rdtsc(); +#endif +} + +#else + +#define cpu_cycles() 0 + +#endif + + +// Generate synthetic data +void generate_data(float offset, size_t n, float * dst) { + for (size_t i = 0; i < n; i++) { + dst[i] = 0.1 + 2*cosf(i + offset); + } +} + +float gigabytes_per_second(size_t bytes, int64_t usecs) { + return bytes / (float) usecs * 1000000 / (1024*1024*1024); +} + +void * align_with_offset(void * ptr, int offset) { + size_t dummy_size = MAX_ALIGNMENT * 4; + return (char *) std::align(MAX_ALIGNMENT, MAX_ALIGNMENT, ptr, dummy_size) + offset; +} + +void benchmark_function(size_t size, size_t q_size, std::function<size_t(void)> function) { + int64_t min_time_us = INT64_MAX; + int64_t total_time_us = 0; + int64_t min_time_cycles = INT64_MAX; + int64_t total_time_cycles = 0; + + for (int i = 0; i < WARMUP; i++) { + function(); + } + + + for (int i = 0; i < ITERATIONS; i++) { + const int64_t start_time = ggml_time_us(); + const int64_t start_cycles = cpu_cycles(); + + function(); + + const int64_t end_cycles = cpu_cycles(); + const int64_t end_time = ggml_time_us(); + + total_time_cycles += end_cycles - start_cycles; + min_time_cycles = std::min(min_time_cycles, end_cycles - start_cycles); + total_time_us += end_time - start_time; + min_time_us = std::min(min_time_us, end_time - start_time); + } + + printf(" min cycles/%d vals : %9.2f\n", QK, QK * min_time_cycles / (float) size); + printf(" avg cycles/%d vals : %9.2f\n", QK, QK * total_time_cycles / (float) (size * ITERATIONS)); + printf(" float32 throughput : %9.2f GB/s\n", gigabytes_per_second(4 * size * ITERATIONS, total_time_us)); + printf(" quantized throughput : %9.2f GB/s\n", gigabytes_per_second(q_size * ITERATIONS, total_time_us)); +} + +int main(int argc, char * argv[]) { + quantize_perf_params params {}; + + // read command line + + bool invalid_param = false; + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + + if (arg == "--size") { + if (++i >= argc) { + invalid_param = true; + break; + } + size_t size = std::stoi(argv[i]); + if (size % 32 != 0) { + fprintf(stderr, "error: size %zu not divisible by 32\n", size); + invalid_param = true; + break; + } + params.test_sizes.push_back(size); + } else if (arg == "-3") { + // quick select sizes that probably fit in CPU caches + params.test_sizes.push_back(L1_SIZE); + params.test_sizes.push_back(L2_SIZE); + params.test_sizes.push_back(L3_SIZE); + } else if (arg == "-4") { + // quick select cache sizes + memory + params.test_sizes.push_back(L1_SIZE); + params.test_sizes.push_back(L2_SIZE); + params.test_sizes.push_back(L3_SIZE); + params.test_sizes.push_back(MEM_SIZE); + } else if (arg == "--op") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string op {argv[i]}; + if (op == "quantize_row_q_reference") { + params.op_quantize_row_q_reference = true; + } else if (op == "quantize_row_q") { + params.op_quantize_row_q = true; + } else if (op == "dequantize_row_q") { + params.op_dequantize_row_q = true; + } else if (op == "quantize_row_q_dot") { + params.op_quantize_row_q_dot = true; + } else if (op == "vec_dot_q") { + params.op_vec_dot_q = true; + } else { + invalid_param = true; + break; + } + } else if (arg == "--type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.include_types.push_back(argv[i]); + } else if (arg == "--alignment-offset") { + if (++i >= argc) { + invalid_param = true; + break; + } + int alignment = std::stoi(argv[i]); + if (alignment < 0 || alignment > MAX_ALIGNMENT) { + fprintf(stderr, "error: aligment-offset must be less than %d\n", MAX_ALIGNMENT); + invalid_param = true; + break; + } + params.alignment_offset = alignment; + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + return 1; + } + } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + return 1; + } + + if (params.test_sizes.empty()) { + params.test_sizes.push_back(L1_SIZE); + } + if (!(params.op_quantize_row_q_reference || params.op_quantize_row_q || params.op_dequantize_row_q || params.op_quantize_row_q_dot || params.op_vec_dot_q)) { + params.op_quantize_row_q_reference = params.op_quantize_row_q = params.op_dequantize_row_q = params.op_quantize_row_q_dot = params.op_vec_dot_q = true; + } + + std::sort(params.test_sizes.begin(), params.test_sizes.end()); + size_t largest = params.test_sizes.back(); + + std::vector<uint8_t> test_data1_v(largest*4 + MAX_ALIGNMENT*2); + std::vector<uint8_t> test_data2_v(largest*4 + MAX_ALIGNMENT*2); + std::vector<uint8_t> test_q1_v(largest*4 + MAX_ALIGNMENT*2); + std::vector<uint8_t> test_q2_v(largest*4 + MAX_ALIGNMENT*2); + std::vector<uint8_t> test_out_v(largest*4 + MAX_ALIGNMENT*2); + + float * test_data1 = (float *) align_with_offset(test_data1_v.data(), params.alignment_offset); + float * test_data2 = (float *) align_with_offset(test_data2_v.data(), params.alignment_offset); + float * test_q1 = (float *) align_with_offset(test_q1_v.data(), params.alignment_offset); + float * test_q2 = (float *) align_with_offset(test_q2_v.data(), params.alignment_offset); + float * test_out = (float *) align_with_offset(test_out_v.data(), params.alignment_offset); + + generate_data(0, largest, test_data1); + generate_data(1, largest, test_data2); + + + // Initialize GGML, ensures float conversion tables are initialized + struct ggml_init_params ggml_params = { + /* .mem_size = */ 1*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ true, + }; + struct ggml_context * ctx = ggml_init(ggml_params); + + for (int i = 0; i < GGML_TYPE_COUNT; i++) { + ggml_type type = (ggml_type) i; + quantize_fns_t qfns = ggml_internal_get_quantize_fn(i); + if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) { + continue; + } + + if (qfns.quantize_row_q) { + printf("%s\n", ggml_type_name(type)); + + if (params.op_quantize_row_q_reference) { + printf(" quantize_row_q_reference\n"); + for (size_t size : params.test_sizes) { + printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); + auto quantize_fn = [&](void ) { + qfns.quantize_row_q_reference(test_data1, test_q1, size); + return test_q1[0]; + }; + size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type); + benchmark_function(size, quantized_size, quantize_fn); + } + printf("\n"); + } + + if (params.op_quantize_row_q) { + printf(" quantize_row_q\n"); + for (size_t size : params.test_sizes) { + printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); + auto quantize_fn = [&](void ) { + qfns.quantize_row_q(test_data1, test_q1, size); + return test_q1[0]; + }; + size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type); + benchmark_function(size, quantized_size, quantize_fn); + } + printf("\n"); + } + + if (params.op_dequantize_row_q) { + printf(" dequantize_row_q\n"); + qfns.quantize_row_q(test_data1, test_q1, largest); + for (size_t size : params.test_sizes) { + printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); + auto quantize_fn = [&](void ) { + qfns.dequantize_row_q(test_q1, test_out, size); + return test_out[0]; + }; + size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type); + benchmark_function(size, quantized_size, quantize_fn); + } + printf("\n"); + } + + if (params.op_quantize_row_q_dot) { + printf(" quantize_row_q_dot\n"); + for (size_t size : params.test_sizes) { + printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); + auto quantize_fn = [&](void ) { + qfns.quantize_row_q_dot(test_data1, test_q1, size); + return test_q1[0]; + }; + size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type); + benchmark_function(size, quantized_size, quantize_fn); + } + printf("\n"); + } + + if (params.op_vec_dot_q) { + printf(" vec_dot_q\n"); + qfns.quantize_row_q(test_data1, test_q1, largest); + qfns.quantize_row_q(test_data2, test_q2, largest); + for (size_t size : params.test_sizes) { + printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); + auto quantize_fn = [&](void ) { + float result; + qfns.vec_dot_q(size, &result, test_q1, test_q2); + return result; + }; + size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type); + benchmark_function(size, quantized_size, quantize_fn); + } + printf("\n"); + } + } + } + + ggml_free(ctx); + + return 0; +} diff --git a/tests/test-quantize.c b/tests/test-quantize.c deleted file mode 100644 index 993e9dc..0000000 --- a/tests/test-quantize.c +++ /dev/null @@ -1,42 +0,0 @@ -#include "ggml.h" -#undef NDEBUG -#include <assert.h> -#include <math.h> - -int main(void) { - #define QK 32 - float src[QK]; - uint8_t dst[24]; - int64_t hist[16]; - - for (int i = 0; i < QK; i++) { - src[i] = (float)(i + 1); - } - - size_t size = ggml_quantize_q4_0(src, dst, QK, QK, hist); - assert(size == 20); - float max_result = ((float *)dst)[0]; - float max_expected = src[31] / ((1 << 3) - 1); - assert(max_result == max_expected); - for (int i = 0; i < QK; i++) { - uint8_t q4_result = (i % 2) ? (dst[sizeof(float) + i/2] >> 4) : (dst[sizeof(float) + i/2] & 0xF); - uint8_t q4_expected = roundf(src[i] / max_expected) + 8; - assert(q4_result == q4_expected); - } - - size = ggml_quantize_q4_1(src, dst, QK, QK, hist); - assert(size == 24); - float delta_result = ((float *)dst)[0]; - float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1); - assert(delta_result == delta_expected); - float min_result = ((float *)dst)[1]; - float min_expected = src[0]; - assert(min_result == min_expected); - for (int i = 0; i < QK; i++) { - uint8_t q4_result = (i % 2) ? (dst[sizeof(float)*2 + i/2] >> 4) : (dst[sizeof(float)*2 + i/2] & 0xF); - uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected); - assert(q4_result == q4_expected); - } - - return 0; -} |