-rw-r--r--  convert-pth-to-ggml.py  192
1 file changed, 84 insertions, 108 deletions
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index d0eb213..8194876 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -16,7 +16,7 @@
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
-import os
+import argparse
 import sys
 import json
 import struct
@@ -24,137 +24,91 @@ import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor
 
-if len(sys.argv) < 3:
-    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
+def parse_args():
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-
-fname_hparams   = sys.argv[1] + "/params.json"
-fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
+    parser.add_argument('dir_model', help='directory containing the model checkpoint')
+    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    return parser.parse_args()
 
 def get_n_parts(dim):
-    if dim == 4096:
-        return 1
-    elif dim == 5120:
-        return 2
-    elif dim == 6656:
-        return 4
-    elif dim == 8192:
-        return 8
-    else:
-        print("Invalid dim: " + str(dim))
+
+    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
+    n_parts = mappings.get(dim)
+    if n_parts is None:
+        print(f"Invalid dim: {dim}")
         sys.exit(1)
 
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-if os.path.exists(fname_out):
-    print(f"Skip conversion, it already exists: {fname_out}")
-    sys.exit(0)
-
-with open(fname_hparams, "r") as f:
-    hparams = json.load(f)
+    print(f"n_parts = {n_parts}\n")
+    return n_parts
 
-tokenizer = SentencePieceProcessor(fname_tokenizer)
+def load_hparams_and_tokenizer(dir_model):
+
+    fname_hparams = f"{dir_model}/params.json"
+    fname_tokenizer = f"{dir_model}/../tokenizer.model"
 
-hparams.update({"vocab_size": tokenizer.vocab_size()})
+    with open(fname_hparams, "r") as f:
+        hparams = json.load(f)
+        print(hparams)
 
-n_parts = get_n_parts(hparams["dim"])
+    tokenizer = SentencePieceProcessor(fname_tokenizer)
+    hparams.update({"vocab_size": tokenizer.vocab_size()})
 
-print(hparams)
-print('n_parts = ', n_parts)
+    return hparams, tokenizer
 
-for p in range(n_parts):
-    print('Processing part ', p)
+def write_header(fout, hparams, ftype):
+
+    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
+    values = [
+        0x67676d6c,  # magic: ggml in hex
+        *[hparams[key] for key in keys],
+        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
+        ftype
+    ]
+    fout.write(struct.pack("i" * len(values), *values))
 
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-    if (p > 0):
-        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
+def write_tokens(fout, tokenizer):
 
-    model = torch.load(fname_model, map_location="cpu")
-
-    fout = open(fname_out, "wb")
-
-    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-    fout.write(struct.pack("i", hparams["vocab_size"]))
-    fout.write(struct.pack("i", hparams["dim"]))
-    fout.write(struct.pack("i", hparams["multiple_of"]))
-    fout.write(struct.pack("i", hparams["n_heads"]))
-    fout.write(struct.pack("i", hparams["n_layers"]))
-    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-    fout.write(struct.pack("i", ftype))
-
-    # Is this correct??
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
-            # "<unk>" token (translated as ??)
             text = " \u2047 ".encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
         elif tokenizer.is_control(i):
-            # "<s>"/"</s>" tokens
-            fout.write(struct.pack("i", 0))
+            text = b""
        elif tokenizer.is_byte(i):
-            # "<U+XX>" tokens (which may be invalid UTF-8)
             piece = tokenizer.id_to_piece(i)
             if len(piece) != 6:
-                print("Invalid token: " + piece)
+                print(f"Invalid token: {piece}")
                 sys.exit(1)
             byte_value = int(piece[3:-1], 16)
-            fout.write(struct.pack("i", 1))
-            fout.write(struct.pack("B", byte_value))
+            text = struct.pack("B", byte_value)
         else:
-            # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
             text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
 
-    for k, v in model.items():
-        name = k
-        shape = v.shape
+def process_and_write_variables(fout, model, ftype):
 
-        # skip layers.X.attention.inner_attention.rope.freqs
-        if name[-5:] == "freqs":
+    for name, data in model.items():
+
+        if name.endswith("freqs"):
             continue
-
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-        #data = tf.train.load_variable(dir_model, name).squeeze()
-        data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+
+        shape = data.shape
+
+        print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n")
+
+        data = np.squeeze(data)
+        n_dims = len(shape)
 
         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
         # "model/h.*/attn/c_proj/w"
         # "model/h.*/mlp/c_fc/w"
         # "model/h.*/mlp/c_proj/w"
-        #if name[-14:] == "/attn/c_attn/w" or \
-        #   name[-14:] == "/attn/c_proj/w" or \
-        #   name[-11:] == "/mlp/c_fc/w" or \
-        #   name[-13:] == "/mlp/c_proj/w":
-        #    print("  Transposing")
+        #if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")):
+        #    print("Transposing")
         #    data = data.transpose()
 
-        dshape = data.shape
-
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
@@ -164,18 +118,40 @@ for p in range(n_parts):
 
         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
-        for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
-
+        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
+        for dim in reversed(data.shape):
+            fout.write(struct.pack("i", dim))
+        fout.write(sname)
+
         # data
         data.tofile(fout)
 
-    # I hope this deallocates the memory ..
-    model = None
+def main():
+
+    args = parse_args()
+    dir_model = args.dir_model
+    ftype = args.ftype
+    ftype_str = ["f32", "f16"]
+
+    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+    n_parts = get_n_parts(hparams["dim"])
+
+    for p in range(n_parts):
+
+        print(f"Processing part {p}\n")
+
+        fname_model = f"{dir_model}/consolidated.0{p}.pth"
+        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            write_header(fout, hparams, ftype)
+            write_tokens(fout, tokenizer)
+            process_and_write_variables(fout, model, ftype)
-
-    fout.close()
+        del model
+        print(f"Done. Output file: {fname_out}, (part {p})\n")
 
-    print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+if __name__ == "__main__":
+    main()
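Note on the output format produced by this patch: write_header packs eight consecutive int32 values with struct.pack("i" * len(values), *values), in the order of its values list: magic (0x67676d6c), vocab_size, dim, multiple_of, n_heads, n_layers, rot (obsolete, dim // n_heads), ftype. The snippet below is not part of the patch; it is a minimal sketch of how that header could be read back for a quick sanity check, and read_ggml_header is a hypothetical helper name used only for illustration.

# Minimal sketch (not part of the patch): read back the eight int32 header fields
# written by write_header, in the same order as its `values` list, and check the magic.
import struct
import sys

def read_ggml_header(path):
    names = ["magic", "vocab_size", "dim", "multiple_of", "n_heads", "n_layers", "rot", "ftype"]
    with open(path, "rb") as f:
        # struct.unpack with the same "i" format used by write_header
        fields = struct.unpack("i" * len(names), f.read(4 * len(names)))
    header = dict(zip(names, fields))
    if header["magic"] != 0x67676d6c:
        sys.exit(f"not a ggml model file, bad magic: {header['magic']:#x}")
    return header

if __name__ == "__main__":
    print(read_ggml_header(sys.argv[1]))

With the argparse interface, the converter is invoked as python convert-pth-to-ggml.py <dir_model> <ftype> (for example, a 7B checkpoint directory; the directory name is illustrative), where ftype 0 selects float32 and 1 selects float16, matching the choices declared in parse_args().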
