author     Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>    2023-06-10 01:59:17 -0600
committer  GitHub <noreply@github.com>                                  2023-06-10 10:59:17 +0300
commit     4f0154b0bad775ac4651bf73b5c216eb43c45cdc (patch)
tree       33a6036c589fd494af7de0cd786e395d4fd3f699 /examples
parent     ef3171d16241c18581d4d08374f0b9e396ade6b7 (diff)
llama : support requantizing models instead of only allowing quantization from 16/32bit (#1691)
* Add support for quantizing already quantized models
* Threaded dequantizing and f16 to f32 conversion
* Clean up thread blocks with spares calculation a bit
* Use std::runtime_error exceptions.
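
The diff below moves the quantize example from the old llama_model_quantize(fname_inp, fname_out, ftype, nthread) call to the new llama_model_quantize_params struct. As a rough orientation, here is a minimal sketch (not part of the commit) of driving that params-based API directly; the file names, target ftype constant, and thread count are placeholder assumptions:

    // Sketch only: requantize an already-quantized model via the
    // params-based API that this commit adopts in quantize.cpp.
    // File names and the target ftype below are hypothetical.
    #include "llama.h"

    #include <cstdio>

    int main() {
        llama_init_backend();

        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype                  = LLAMA_FTYPE_MOSTLY_Q4_0; // target quantization type (placeholder)
        params.allow_requantize       = true;  // source model is already quantized
        params.quantize_output_tensor = false; // same effect as --leave-output-tensor
        params.nthread                = 4;     // 0 keeps the library default

        // non-zero return indicates failure, as the example's main() checks below
        if (llama_model_quantize("ggml-model-q8_0.bin", "ggml-model-q4_0.bin", &params)) {
            fprintf(stderr, "failed to quantize model\n");
            return 1;
        }
        return 0;
    }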
Diffstat (limited to 'examples')
-rw-r--r--    examples/quantize/quantize.cpp    57
1 file changed, 38 insertions, 19 deletions
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 947b402..c6bf1b7 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <cstdio>
+#include <cstring>
 #include <map>
 #include <string>
 
@@ -53,27 +54,49 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st
 // usage:
 //  ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
 //
+void usage(const char * executable) {
+    fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", executable);
+    fprintf(stderr, "  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
+    fprintf(stderr, "  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    fprintf(stderr, "Allowed quantization types:\n");
+    for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
+        fprintf(stderr, "  type = \"%s\" or %d\n", it->first.c_str(), it->second);
+    }
+    exit(1);
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
-        fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
-        for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
-            fprintf(stderr, "  type = \"%s\" or %d\n", it->first.c_str(), it->second);
+        usage(argv[0]);
+    }
+
+    llama_model_quantize_params params = llama_model_quantize_default_params();
+
+    int arg_idx = 1;
+
+    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
+        if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
+            params.quantize_output_tensor = false;
+        } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
+            params.allow_requantize = true;
+        } else {
+            usage(argv[0]);
         }
-        return 1;
+    }
+
+    if (argc - arg_idx < 3) {
+        usage(argv[0]);
     }
 
     llama_init_backend();
 
     // parse command line arguments
-    const std::string fname_inp = argv[1];
+    const std::string fname_inp = argv[arg_idx];
+    arg_idx++;
     std::string fname_out;
-    int nthread;
-    llama_ftype ftype;
 
-    int arg_idx = 2;
     std::string ftype_str;
-    if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
-        // argv[2] is the ftype
+    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
         std::string fpath;
         const size_t pos = fname_inp.find_last_of('/');
         if (pos != std::string::npos) {
@@ -84,7 +107,6 @@ int main(int argc, char ** argv) {
         arg_idx++;
     }
     else {
-        // argv[2] is the output path
         fname_out = argv[arg_idx];
         arg_idx++;
 
@@ -92,8 +114,7 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s: missing ftype\n", __func__);
             return 1;
         }
-        // argv[3] is the ftype
-        if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+        if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
             fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
             return 1;
         }
@@ -103,21 +124,19 @@ int main(int argc, char ** argv) {
     // parse nthreads
     if (argc > arg_idx) {
         try {
-            nthread = std::stoi(argv[arg_idx]);
+            params.nthread = std::stoi(argv[arg_idx]);
         }
         catch (const std::exception & e) {
             fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
             return 1;
         }
-    } else {
-        nthread = 0;
     }
 
     fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
 
     fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
-    if (nthread > 0) {
-        fprintf(stderr, " using %d threads", nthread);
+    if (params.nthread > 0) {
+        fprintf(stderr, " using %d threads", params.nthread);
     }
     fprintf(stderr, "\n");
 
@@ -129,7 +148,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = llama_time_us();
 
-        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
+        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), &params)) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }
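
With these flags in place, the rebuilt tool would be invoked along these lines (a hypothetical example; the model file names, type string, and thread count are illustrative, not taken from the commit):

    ./quantize --allow-requantize ggml-model-q8_0.bin ggml-model-q4_0.bin q4_0 8

Passing --leave-output-tensor additionally keeps output.weight un(re)quantized, trading a larger output file for potentially better quality, per the usage text added above.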