aboutsummaryrefslogtreecommitdiff
path: root/quantize.py
blob: 641df8dda1b1e5efe1aec55082b3fa6a83b41f64 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3

"""Script to execute the "quantize" script on a given set of models."""

import subprocess
import argparse
import glob
import sys
import os


def _quantize_binary_name():
    """Return the platform-specific file name of the quantize binary.

    Linux and macOS use "quantize", Windows/Cygwin use "quantize.exe".
    Any other platform gets a printed warning and the UNIX-style name.
    """
    if "linux" in sys.platform or "darwin" in sys.platform:
        return "quantize"

    if "win32" in sys.platform or "cygwin" in sys.platform:
        return "quantize.exe"

    print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
    return "quantize"


def _find_f16_model_parts(f16_model_path_base):
    """Return the sorted paths of every part of a split f16 model.

    The model is stored as several files sharing a common prefix
    (ggml-model-f16.bin, ggml-model-f16.bin.1, ...), so every path that
    starts with ``f16_model_path_base`` is returned. The prefix is escaped
    so that glob metacharacters ("[", "*", "?") appearing in directory
    names are matched literally instead of being treated as a pattern.
    Sorting makes the processing order deterministic.
    """
    return sorted(glob.glob(glob.escape(f16_model_path_base) + "*"))


def main():
    """Quantize the requested models with the platform's quantize binary.

    Parses the command line, checks that the quantize binary exists, and
    then runs it on every part of each requested f16 model. If
    ``--remove-16`` is given, each f16 part is deleted once it has been
    quantized. Exits with status 1 when the binary or a model file is
    missing. Errors raised by the quantize step propagate to the caller.
    """
    quantize_script_binary = _quantize_binary_name()

    parser = argparse.ArgumentParser(
        prog='python3 quantize.py',
        description='This script quantizes the given models by applying the '
        f'"{quantize_script_binary}" script on them.'
    )
    parser.add_argument(
        'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
        help='The models to quantize.'
    )
    parser.add_argument(
        '-r', '--remove-16', action='store_true', dest='remove_f16',
        help='Remove the f16 model after quantizing it.'
    )
    parser.add_argument(
        '-m', '--models-path', dest='models_path',
        default=os.path.join(os.getcwd(), "models"),
        help='Specify the directory where the models are located.'
    )
    parser.add_argument(
        '-q', '--quantize-script-path', dest='quantize_script_path',
        default=os.path.join(os.getcwd(), quantize_script_binary),
        help='Specify the path to the "quantize" script.'
    )

    # TODO: Revise this code
    # parser.add_argument(
    #     '-t', '--threads', dest='threads', type=int,
    #     default=os.cpu_count(),
    #     help='Specify the number of threads to use to quantize many models at '
    #     'once. Defaults to os.cpu_count().'
    # )

    args = parser.parse_args()
    args.models_path = os.path.abspath(args.models_path)

    if not os.path.isfile(args.quantize_script_path):
        # Report the path that was actually checked: it need not be the
        # current directory when --quantize-script-path was given.
        print(
            f'The "{quantize_script_binary}" script was not found at '
            f"{args.quantize_script_path}.\nIf you want to use it from another "
            "location, set the --quantize-script-path argument from the "
            "command line."
        )
        sys.exit(1)

    for model in args.models:
        # The model is separated in various parts
        # (ggml-model-f16.bin, ggml-model-f16.bin.1, ...); the first part
        # must exist, the rest are discovered by prefix.
        f16_model_path_base = os.path.join(
            args.models_path, model, "ggml-model-f16.bin"
        )

        if not os.path.isfile(f16_model_path_base):
            print(f"The file {f16_model_path_base} was not found")
            sys.exit(1)

        for f16_model_part_path in _find_f16_model_parts(f16_model_path_base):
            # A prefix match can also be a directory; refuse to quantize it.
            if not os.path.isfile(f16_model_part_path):
                print(
                    f"The f16 model {os.path.basename(f16_model_part_path)} "
                    f"was not found in {args.models_path}{os.path.sep}{model}"
                    ". If you want to use it from another location, set the "
                    "--models-path argument from the command line."
                )
                sys.exit(1)

            __run_quantize_script(
                args.quantize_script_path, f16_model_part_path
            )

            if args.remove_f16:
                os.remove(f16_model_part_path)


# This was extracted to a top-level function for parallelization, if
# implemented. See https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406

def _quantized_model_path(f16_model_part_path):
    """Return the output path for the q4_0 quantization of an f16 model part.

    Only the file name is rewritten ("f16" -> "q4_0"). Directory components
    are left untouched, so a directory whose name happens to contain "f16"
    does not redirect the output to a nonexistent location.

    Raises:
        ValueError: if the file name does not contain "f16". In that case
            the output path would equal the input path and the quantize
            binary would overwrite its own source.
    """
    directory, filename = os.path.split(f16_model_part_path)
    if "f16" not in filename:
        raise ValueError(
            f"Cannot derive a quantized file name from {f16_model_part_path!r}"
        )
    return os.path.join(directory, filename.replace("f16", "q4_0"))


def __run_quantize_script(script_path, f16_model_part_path):
    """Run the quantize script specifying the path to it and the path to the
    f16 model to quantize.

    The output is written next to the input, with "f16" replaced by "q4_0"
    in the file name. Raises subprocess.CalledProcessError if the script
    exits with a non-zero status (check=True).
    """
    new_quantized_model_path = _quantized_model_path(f16_model_part_path)
    # NOTE(review): the trailing "2" is presumably the quantization type id
    # for q4_0, matching the output file name; confirm against the quantize
    # binary's usage text.
    subprocess.run(
        [script_path, f16_model_part_path, new_quantized_model_path, "2"],
        check=True
    )


if __name__ == "__main__":
    # Script entry point: turn quantize failures and Ctrl-C into clean exit
    # statuses instead of tracebacks.
    try:
        main()

    except subprocess.CalledProcessError:
        print("\nAn error occurred while trying to quantize the models.")
        sys.exit(1)

    except KeyboardInterrupt:
        # Exit with the conventional 128 + SIGINT status so that callers
        # (e.g. `python3 quantize.py 7B && next-step`) do not mistake an
        # interrupted, partially-finished run for a successful one.
        sys.exit(130)

    else:
        print("\nSuccessfully quantized all models.")