diff options
-rw-r--r-- | convert-pth-to-ggml.py | 192 |
1 files changed, 84 insertions, 108 deletions
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d0eb213..8194876 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -16,7 +16,7 @@ # At the start of the ggml file we write the model parameters # and vocabulary. # -import os +import argparse import sys import json import struct @@ -24,137 +24,91 @@ import numpy as np import torch from sentencepiece import SentencePieceProcessor -if len(sys.argv) < 3: - print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n") - print(" ftype == 0 -> float32") - print(" ftype == 1 -> float16") - sys.exit(1) +def parse_args(): -# output in the same directory as the model -dir_model = sys.argv[1] - -fname_hparams = sys.argv[1] + "/params.json" -fname_tokenizer = sys.argv[1] + "/../tokenizer.model" + parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') + parser.add_argument('dir_model', help='directory containing the model checkpoint') + parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') + return parser.parse_args() def get_n_parts(dim): - if dim == 4096: - return 1 - elif dim == 5120: - return 2 - elif dim == 6656: - return 4 - elif dim == 8192: - return 8 - else: - print("Invalid dim: " + str(dim)) + + mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} + n_parts = mappings.get(dim) + if n_parts is None: + print(f"Invalid dim: {dim}") sys.exit(1) -# possible data types -# ftype == 0 -> float32 -# ftype == 1 -> float16 -# -# map from ftype to string -ftype_str = ["f32", "f16"] - -ftype = 1 -if len(sys.argv) > 2: - ftype = int(sys.argv[2]) - if ftype < 0 or ftype > 1: - print("Invalid ftype: " + str(ftype)) - sys.exit(1) - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" - -if os.path.exists(fname_out): - print(f"Skip conversion, it already exists: {fname_out}") - sys.exit(0) - -with open(fname_hparams, "r") as f: - hparams = json.load(f) + print(f"n_parts = {n_parts}\n") + return n_parts -tokenizer = SentencePieceProcessor(fname_tokenizer) +def load_hparams_and_tokenizer(dir_model): + + fname_hparams = f"{dir_model}/params.json" + fname_tokenizer = f"{dir_model}/../tokenizer.model" -hparams.update({"vocab_size": tokenizer.vocab_size()}) + with open(fname_hparams, "r") as f: + hparams = json.load(f) + print(hparams) -n_parts = get_n_parts(hparams["dim"]) + tokenizer = SentencePieceProcessor(fname_tokenizer) + hparams.update({"vocab_size": tokenizer.vocab_size()}) -print(hparams) -print('n_parts = ', n_parts) + return hparams, tokenizer -for p in range(n_parts): - print('Processing part ', p) +def write_header(fout, hparams, ftype): + + keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] + values = [ + 0x67676d6c, # magic: ggml in hex + *[hparams[key] for key in keys], + hparams["dim"] // hparams["n_heads"], # rot (obsolete) + ftype + ] + fout.write(struct.pack("i" * len(values), *values)) - #fname_model = sys.argv[1] + "/consolidated.00.pth" - fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth" - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" - if (p > 0): - fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p) +def write_tokens(fout, tokenizer): - model = torch.load(fname_model, map_location="cpu") - - fout = open(fname_out, "wb") - - fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex - fout.write(struct.pack("i", hparams["vocab_size"])) - fout.write(struct.pack("i", hparams["dim"])) - fout.write(struct.pack("i", hparams["multiple_of"])) - fout.write(struct.pack("i", hparams["n_heads"])) - fout.write(struct.pack("i", hparams["n_layers"])) - fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete) - fout.write(struct.pack("i", ftype)) - - # Is this correct?? for i in range(tokenizer.vocab_size()): if tokenizer.is_unknown(i): - # "<unk>" token (translated as ??) text = " \u2047 ".encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) elif tokenizer.is_control(i): - # "<s>"/"</s>" tokens - fout.write(struct.pack("i", 0)) + text = b"" elif tokenizer.is_byte(i): - # "<U+XX>" tokens (which may be invalid UTF-8) piece = tokenizer.id_to_piece(i) if len(piece) != 6: - print("Invalid token: " + piece) + print(f"Invalid token: {piece}") sys.exit(1) byte_value = int(piece[3:-1], 16) - fout.write(struct.pack("i", 1)) - fout.write(struct.pack("B", byte_value)) + text = struct.pack("B", byte_value) else: - # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces. text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - fout.write(struct.pack("i", len(text))) - fout.write(text) + fout.write(struct.pack("i", len(text))) + fout.write(text) - for k, v in model.items(): - name = k - shape = v.shape +def process_and_write_variables(fout, model, ftype): - # skip layers.X.attention.inner_attention.rope.freqs - if name[-5:] == "freqs": + for name, data in model.items(): + + if name.endswith("freqs"): continue - - print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype) - - #data = tf.train.load_variable(dir_model, name).squeeze() - data = v.numpy().squeeze() - n_dims = len(data.shape); + + shape = data.shape + + print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n") + + data = np.squeeze(data) + n_dims = len(shape) # for efficiency - transpose some matrices # "model/h.*/attn/c_attn/w" # "model/h.*/attn/c_proj/w" # "model/h.*/mlp/c_fc/w" # "model/h.*/mlp/c_proj/w" - #if name[-14:] == "/attn/c_attn/w" or \ - # name[-14:] == "/attn/c_proj/w" or \ - # name[-11:] == "/mlp/c_fc/w" or \ - # name[-13:] == "/mlp/c_proj/w": - # print(" Transposing") + #if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")): + # print("Transposing") # data = data.transpose() - dshape = data.shape - # default type is fp16 ftype_cur = 1 if ftype == 0 or n_dims == 1: @@ -164,18 +118,40 @@ for p in range(n_parts): # header sname = name.encode('utf-8') - fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) - for i in range(n_dims): - fout.write(struct.pack("i", dshape[n_dims - 1 - i])) - fout.write(sname); - + fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur)) + for dim in reversed(data.shape): + fout.write(struct.pack("i", dim)) + fout.write(sname) + # data data.tofile(fout) - # I hope this deallocates the memory .. - model = None +def main(): + + args = parse_args() + dir_model = args.dir_model + ftype = args.ftype + ftype_str = ["f32", "f16"] + + hparams, tokenizer = load_hparams_and_tokenizer(dir_model) + n_parts = get_n_parts(hparams["dim"]) + + for p in range(n_parts): + + print(f"Processing part {p}\n") + + fname_model = f"{dir_model}/consolidated.0{p}.pth" + fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}" + + model = torch.load(fname_model, map_location="cpu") + + with open(fname_out, "wb") as fout: + write_header(fout, hparams, ftype) + write_tokens(fout, tokenizer) + process_and_write_variables(fout, model, ftype) - fout.close() + del model + print(f"Done. Output file: {fname_out}, (part {p})\n") - print("Done. Output file: " + fname_out + ", (part ", p, ")") - print("") +if __name__ == "__main__": + main() |