aboutsummaryrefslogtreecommitdiff
path: root/convert-gptq-to-ggml.py
blob: 7fccb4d569d8fbe50bc5a6cad613e77e12f8f739 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Convert a GPTQ quantized LLaMA model to a ggml compatible file
# Based on: https://github.com/qwopqwop200/GPTQ-for-LLaMa
#
import os
import re
import sys
import json
import struct
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor

if len(sys.argv) != 4:
    print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
    sys.exit(1)

fname_model = sys.argv[1]
fname_tokenizer = sys.argv[2]
fname_out = sys.argv[3]

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only convert model files from a trusted source.
model = torch.load(fname_model, map_location="cpu")

# Dimensions are inferred from the loaded tensors: vocab/embedding sizes from
# the embedding matrix shape, layer count from the highest "model.layers.N"
# key index.
n_vocab, n_embd = model['model.embed_tokens.weight'].shape
n_layer = 1 + max(int(m.group(1)) for name in model
                  if (m := re.match(r'model\.layers\.([0-9]+)', name)))

# hardcoded: n_mult, and the head count per supported layer count.
n_mult = 256
n_head_by_n_layer = {32: 32, 40: 40, 60: 52, 80: 64}
if n_layer not in n_head_by_n_layer:
    print(f"Unsupported model: {n_layer} layers "
          f"(supported layer counts: {sorted(n_head_by_n_layer)})")
    sys.exit(1)
n_head = n_head_by_n_layer[n_layer]

tokenizer = SentencePieceProcessor(fname_tokenizer)

# Explicit check instead of `assert`, so it still runs under `python -O`.
if tokenizer.vocab_size() != n_vocab:
    print(f"Tokenizer vocab size {tokenizer.vocab_size()} does not match "
          f"model vocab size {n_vocab}")
    sys.exit(1)

# Module-level handle: write_header / convert_* write through this global,
# and it is closed at the end of the script.
fout = open(fname_out, "wb")

# File header: magic, then hyperparameters, each as a native C int.
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", n_vocab))
fout.write(struct.pack("i", n_embd))
fout.write(struct.pack("i", n_mult))
fout.write(struct.pack("i", n_head))
fout.write(struct.pack("i", n_layer))
fout.write(struct.pack("i", n_embd // n_head)) # rot (obsolete)
# NOTE(review): file-level type code; presumably "mostly Q4_1 with some F16
# tensors" for the loader -- confirm against the consuming code.
fout.write(struct.pack("i", 4))


# Vocabulary: for each token id, a C-int byte length followed by the bytes.
# This loop unchanged from convert-pth-to-ggml.py:
for i in range(tokenizer.vocab_size()):
    if tokenizer.is_unknown(i):
        # "<unk>" token (translated as ??)
        text = " \u2047 ".encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
    elif tokenizer.is_control(i):
        # "<s>"/"</s>" tokens
        fout.write(struct.pack("i", 0))
    elif tokenizer.is_byte(i):
        # "<0xXX>" byte tokens: a single raw byte (which may be invalid UTF-8).
        # piece[3:-1] extracts the two hex digits.
        piece = tokenizer.id_to_piece(i)
        if len(piece) != 6:
            print("Invalid token: " + piece)
            sys.exit(1)
        byte_value = int(piece[3:-1], 16)
        fout.write(struct.pack("i", 1))
        fout.write(struct.pack("B", byte_value))
    else:
        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
        text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)

def write_header(shape, dst_name, ftype_cur):
    """Write one tensor's record header to the module-level ``fout``.

    Layout, all native C ints unless noted: number of dimensions, byte length
    of the UTF-8 encoded name, the tensor type code ``ftype_cur``, then every
    dimension in reverse order, followed by the raw name bytes.
    """
    name_bytes = dst_name.encode('utf-8')
    dims_reversed = tuple(shape)[::-1]
    rank = len(dims_reversed)
    header_fmt = "iii" + "i" * rank
    fout.write(struct.pack(header_fmt, rank, len(name_bytes), ftype_cur, *dims_reversed))
    fout.write(name_bytes)

def convert_non_q4(src_name, dst_name):
    """Copy one unquantized tensor from ``model`` into ``fout``.

    One-dimensional tensors are upcast to float32 before writing; any other
    tensor keeps its stored dtype, which must be float16 (type code 1) or
    float32 (type code 0) -- other dtypes raise KeyError.
    """
    tensor = model[src_name]
    dims = tensor.shape
    print("Processing non-Q4 variable: " + src_name + " with shape: ", dims, " and type: ", tensor.dtype)

    if len(dims) == 1:
        print("  Converting to float32")
        tensor = tensor.to(torch.float32)

    type_code_by_dtype = {torch.float32: 0, torch.float16: 1}
    write_header(dims, dst_name, type_code_by_dtype[tensor.dtype])

    tensor.numpy().tofile(fout)

def convert_q4(src_name, dst_name, permute=False):
    """Convert one GPTQ-quantized linear layer into a ggml Q4_1 tensor.

    Reads ``{src_name}.qweight``, ``.zeros``, ``.scales`` and, if present,
    ``.bias`` from the module-level ``model``. It then writes a tensor header
    and the re-packed data to the module-level ``fout``.

    Args:
        src_name: key prefix of the layer inside ``model``.
        dst_name: tensor name recorded in the output file.
        permute: when True, undo the per-head row permutation (used for the
            q/k projections; relies on the module-level ``n_head``).

    Raises:
        ValueError: if the layer has a non-zero bias (Q4_1 has no bias term),
            or if zeros/scales do not hold exactly one value per output row.
    """
    qweight = model[f"{src_name}.qweight"].numpy().T # transpose
    # Upcast to float32 first: the output stores scale/addend as 4-byte
    # float32 values, and the int32 reinterpretation below requires 4-byte
    # floats (a float16 buffer could not be viewed as int32 per element).
    zeros = model[f"{src_name}.zeros"].float().numpy()
    scales = model[f"{src_name}.scales"].float().numpy()

    # Q4_1 does not support bias, so only an absent or all-zero bias can be
    # converted. Raise instead of assert: asserts vanish under `python -O`,
    # which would silently drop a real bias.
    bias_key = f"{src_name}.bias"
    if bias_key in model and torch.any(model[bias_key] != 0):
        raise ValueError(f"{src_name}: non-zero bias cannot be represented in Q4_1")

    # Each int32 item is actually 8 int4 items packed together, and it's transposed.
    n_rows, n_packed = qweight.shape
    shape = (n_rows, n_packed * 8)

    print("Processing Q4 variable: " + src_name + " with shape: ", shape)

    # The output format has the int4 weights in groups of 32 rather than 8.
    # It looks like this (matching the np.concatenate order below):
    # For each row:
    #   For each group of 32 columns:
    #     - scale (float32, 4 bytes)
    #     - addend (float32, 4 bytes)
    #     - weights (int4 * 32, 16 bytes)
    # Note that in the input, the scales and addends are shared between all
    # the columns in a row, so we end up wasting quite a bit of memory with
    # repeated scales and addends.

    addends = -zeros # flip sign

    # One scale/addend per output row, shaped (rows, 1, 1) so it can be
    # repeated along the group axis. The explicit reshape accepts both
    # (rows, 1) and (rows,) buffers and raises ValueError on any other size.
    # Since the output format is mixed between integers and floats, view the
    # floats as int32s just so numpy will let us concatenate them.
    addends_view = addends.reshape(n_rows, 1, 1).view(dtype=np.int32)
    scales_view = scales.reshape(n_rows, 1, 1).view(dtype=np.int32)

    # Split into groups of 4 columns (i.e. 32 columns of quantized data):
    grouped = qweight.reshape([n_rows, n_packed // 4, 4])

    # Repeat addends and scales once per group:
    n_groups = grouped.shape[1]
    addends_rep = addends_view.repeat(n_groups, axis=1)
    scales_rep = scales_view.repeat(n_groups, axis=1)

    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting='no')

    if permute:
        # Permute some rows to undo the permutation done by convert_llama_weights_to_hf.py.
        # This can be done after the above conversion because it doesn't affect column order/layout.
        blob = (blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
                    .swapaxes(1, 2)
                    .reshape(blob.shape))

    # header
    write_header(shape, dst_name, 3) # ftype = Q4_1

    # data
    blob.tofile(fout)

# Global (non-layer) tensors, written first.
convert_non_q4("model.embed_tokens.weight", "tok_embeddings.weight")
convert_non_q4("model.norm.weight", "norm.weight")
convert_non_q4("lm_head.weight", "output.weight")

# Per-layer quantized projections, in output order:
# (source suffix, destination suffix, undo the row permutation?)
layer_q4_tensors = [
    ("self_attn.q_proj", "attention.wq.weight", True),
    ("self_attn.k_proj", "attention.wk.weight", True),
    ("self_attn.v_proj", "attention.wv.weight", False),
    ("self_attn.o_proj", "attention.wo.weight", False),
    ("mlp.gate_proj", "feed_forward.w1.weight", False),
    ("mlp.down_proj", "feed_forward.w2.weight", False),
    ("mlp.up_proj", "feed_forward.w3.weight", False),
]
# Per-layer unquantized tensors, written after the projections of each layer.
layer_plain_tensors = [
    ("input_layernorm.weight", "attention_norm.weight"),
    ("post_attention_layernorm.weight", "ffn_norm.weight"),
]

for layer_idx in range(n_layer):
    src_prefix = f"model.layers.{layer_idx}."
    dst_prefix = f"layers.{layer_idx}."
    for src_suffix, dst_suffix, needs_permute in layer_q4_tensors:
        convert_q4(src_prefix + src_suffix, dst_prefix + dst_suffix, permute=needs_permute)
    for src_suffix, dst_suffix in layer_plain_tensors:
        convert_non_q4(src_prefix + src_suffix, dst_prefix + dst_suffix)

fout.close()

print("Done. Output file: " + fname_out)
print("")