author     Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>    2023-06-10 01:59:17 -0600
committer  GitHub <noreply@github.com>    2023-06-10 10:59:17 +0300
commit     4f0154b0bad775ac4651bf73b5c216eb43c45cdc (patch)
tree       33a6036c589fd494af7de0cd786e395d4fd3f699 /examples
parent     ef3171d16241c18581d4d08374f0b9e396ade6b7 (diff)
llama : support requantizing models instead of only allowing quantization from 16/32bit (#1691)
* Add support for quantizing already quantized models
* Threaded dequantizing and f16 to f32 conversion
* Clean up thread blocks with spares calculation a bit
* Use std::runtime_error exceptions.
Diffstat (limited to 'examples')
-rw-r--r--  examples/quantize/quantize.cpp  57
1 file changed, 38 insertions(+), 19 deletions(-)
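
The hunks below drop the old (input, output, ftype, nthread) argument list in favor of a llama_model_quantize_params struct. As a rough sketch of how the reworked call is driven, using only the identifiers that appear in this diff (the file names and option values are placeholders, not part of the commit):

#include "llama.h"

int quantize_sketch() {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.allow_requantize       = true;  // allow quantized -> quantized conversion
    params.quantize_output_tensor = false; // leave output.weight un(re)quantized
    params.nthread                = 4;     // per the diff, 0 is treated as "not specified"
    // params.ftype would be set via try_parse_ftype() or an LLAMA_FTYPE_* value

    llama_init_backend();

    // non-zero return indicates failure, mirroring the check in quantize.cpp
    if (llama_model_quantize("model-q4_0.bin", "model-requant.bin", &params)) {
        return 1;
    }
    return 0;
}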
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 947b402..c6bf1b7 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -3,6 +3,7 @@
#include "llama.h"
#include <cstdio>
+#include <cstring>
#include <map>
#include <string>
@@ -53,27 +54,49 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st
// usage:
// ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
//
+void usage(const char * executable) {
+ fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", executable);
+ fprintf(stderr, " --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
+ fprintf(stderr, " --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+ fprintf(stderr, "Allowed quantization types:\n");
+ for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
+ fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+ }
+ exit(1);
+}
+
int main(int argc, char ** argv) {
if (argc < 3) {
- fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
- for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
- fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+ usage(argv[0]);
+ }
+
+ llama_model_quantize_params params = llama_model_quantize_default_params();
+
+ int arg_idx = 1;
+
+ for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
+ if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
+ params.quantize_output_tensor = false;
+ } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
+ params.allow_requantize = true;
+ } else {
+ usage(argv[0]);
}
- return 1;
+ }
+
+ if (argc - arg_idx < 3) {
+ usage(argv[0]);
}
llama_init_backend();
// parse command line arguments
- const std::string fname_inp = argv[1];
+ const std::string fname_inp = argv[arg_idx];
+ arg_idx++;
std::string fname_out;
- int nthread;
- llama_ftype ftype;
- int arg_idx = 2;
std::string ftype_str;
- if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
- // argv[2] is the ftype
+ if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
std::string fpath;
const size_t pos = fname_inp.find_last_of('/');
if (pos != std::string::npos) {
@@ -84,7 +107,6 @@ int main(int argc, char ** argv) {
arg_idx++;
}
else {
- // argv[2] is the output path
fname_out = argv[arg_idx];
arg_idx++;
@@ -92,8 +114,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: missing ftype\n", __func__);
return 1;
}
- // argv[3] is the ftype
- if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+ if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
return 1;
}
@@ -103,21 +124,19 @@ int main(int argc, char ** argv) {
// parse nthreads
if (argc > arg_idx) {
try {
- nthread = std::stoi(argv[arg_idx]);
+ params.nthread = std::stoi(argv[arg_idx]);
}
catch (const std::exception & e) {
fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
return 1;
}
- } else {
- nthread = 0;
}
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
- if (nthread > 0) {
- fprintf(stderr, " using %d threads", nthread);
+ if (params.nthread > 0) {
+ fprintf(stderr, " using %d threads", params.nthread);
}
fprintf(stderr, "\n");
@@ -129,7 +148,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = llama_time_us();
- if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
+ if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), &params)) {
fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
return 1;
}
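
With this change, an already-quantized model can be fed back through the tool when --allow-requantize is given, e.g. something like: ./quantize --allow-requantize models/llama/ggml-model-q4_0.bin models/llama/ggml-model-requant.bin q4_1 8 (the paths and the q4_1 type are illustrative, not taken from the commit). As the new usage text warns, requantizing can severely reduce quality compared to quantizing from 16-bit or 32-bit weights; --leave-output-tensor keeps output.weight un(re)quantized, which increases model size but may help quality, especially when requantizing.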