aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--quantize.py126
-rwxr-xr-xquantize.sh15
3 files changed, 127 insertions, 16 deletions
diff --git a/README.md b/README.md
index 504c101..dae1bf1 100644
--- a/README.md
+++ b/README.md
@@ -147,7 +147,7 @@ python3 -m pip install torch numpy sentencepiece
python3 convert-pth-to-ggml.py models/7B/ 1
# quantize the model to 4-bits
-./quantize.sh 7B
+python3 quantize.py 7B
# run the inference
./main -m ./models/7B/ggml-model-q4_0.bin -n 128
diff --git a/quantize.py b/quantize.py
new file mode 100644
index 0000000..6320b0a
--- /dev/null
+++ b/quantize.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+"""Script to execute the "quantize" script on a given set of models."""
+
+import subprocess
+import argparse
+import glob
+import sys
+import os
+
+
+def main():
+ """Update the quantize binary name depending on the platform and parse
+ the command line arguments and execute the script.
+ """
+
+ if "linux" in sys.platform or "darwin" in sys.platform:
+ quantize_script_binary = "quantize"
+
+ elif "win32" in sys.platform or "cygwin" in sys.platform:
+ quantize_script_binary = "quantize.exe"
+
+ else:
+ print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
+ quantize_script_binary = "quantize"
+
+ parser = argparse.ArgumentParser(
+ prog='python3 quantize.py',
+ description='This script quantizes the given models by applying the '
+ f'"{quantize_script_binary}" script on them.'
+ )
+ parser.add_argument(
+ 'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
+ help='The models to quantize.'
+ )
+ parser.add_argument(
+ '-r', '--remove-f16', action='store_true', dest='remove_f16',
+ help='Remove the f16 model after quantizing it.'
+ )
+ parser.add_argument(
+ '-m', '--models-path', dest='models_path',
+ default=os.path.join(os.getcwd(), "models"),
+ help='Specify the directory where the models are located.'
+ )
+ parser.add_argument(
+ '-q', '--quantize-script-path', dest='quantize_script_path',
+ default=os.path.join(os.getcwd(), quantize_script_binary),
+ help='Specify the path to the "quantize" script.'
+ )
+
+ # TODO: Revise this code
+ # parser.add_argument(
+ # '-t', '--threads', dest='threads', type=int,
+ # default=os.cpu_count(),
+ # help='Specify the number of threads to use to quantize many models at '
+ # 'once. Defaults to os.cpu_count().'
+ # )
+
+ args = parser.parse_args()
+
+ if not os.path.isfile(args.quantize_script_path):
+ print(
+ f'The "{quantize_script_binary}" script was not found in the '
+ "current location.\nIf you want to use it from another location, "
+ "set the --quantize-script-path argument from the command line."
+ )
+ sys.exit(1)
+
+ for model in args.models:
+ # The model is separated in various parts
+ # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1...)
+ f16_model_path_base = os.path.join(
+ args.models_path, model, "ggml-model-f16.bin"
+ )
+
+ # NOTE: glob already returns each part's full path (the pattern is
+ # rooted at f16_model_path_base), so re-joining onto the base would
+ # corrupt the paths whenever --models-path is a relative directory.
+ f16_model_parts_paths = glob.glob(f"{f16_model_path_base}*")
+
+ for f16_model_part_path in f16_model_parts_paths:
+ if not os.path.isfile(f16_model_part_path):
+ print(
+ f"The f16 model {os.path.basename(f16_model_part_path)} "
+ f"was not found in {args.models_path}{os.path.sep}{model}"
+ ". If you want to use it from another location, set the "
+ "--models-path argument from the command line."
+ )
+ sys.exit(1)
+
+ __run_quantize_script(
+ args.quantize_script_path, f16_model_part_path
+ )
+
+ if args.remove_f16:
+ os.remove(f16_model_part_path)
+
+
+# This was extracted to a top-level function for parallelization, if
+# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406
+
+def __run_quantize_script(script_path, f16_model_part_path):
+ """Run the quantize script specifying the path to it and the path to the
+ f16 model to quantize.
+ """
+
+ new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
+ subprocess.run(
+ [script_path, f16_model_part_path, new_quantized_model_path, "2"],
+ check=True
+ )
+
+
+if __name__ == "__main__":
+ try:
+ main()
+
+ except subprocess.CalledProcessError:
+ print("\nAn error ocurred while trying to quantize the models.")
+ sys.exit(1)
+
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+ else:
+ print("\nSuccesfully quantized all models.")
diff --git a/quantize.sh b/quantize.sh
deleted file mode 100755
index 6194649..0000000
--- a/quantize.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env bash
-
-if ! [[ "$1" =~ ^[0-9]{1,2}B$ ]]; then
- echo
- echo "Usage: quantize.sh 7B|13B|30B|65B [--remove-f16]"
- echo
- exit 1
-fi
-
-for i in `ls models/$1/ggml-model-f16.bin*`; do
- ./quantize "$i" "${i/f16/q4_0}" 2
- if [[ "$2" == "--remove-f16" ]]; then
- rm "$i"
- fi
-done