diff options
Diffstat (limited to 'convert-pth-to-ggml.py')
-rw-r--r-- | convert-pth-to-ggml.py | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index 108eb1f..c506676 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -10,12 +10,10 @@
 # - Name (char[name_length])
 # - Data (float[n_dims])
 #
-# By default, the bigger matrices are converted to 16-bit floats.
-# This can be disabled by adding the "use-f32" CLI argument.
-#
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
+
 import argparse
 import os
 import sys
@@ -23,6 +21,7 @@ import json
 import struct
 import numpy as np
 import torch
+
 from sentencepiece import SentencePieceProcessor
 
 def parse_args():
@@ -30,6 +29,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser.add_argument('dir_model', help='directory containing the model checkpoint')
     parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file')
     return parser.parse_args()
 
 def get_n_parts(dim):
@@ -134,6 +134,27 @@ def main():
     ftype_str = ["f32", "f16"]
 
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+
+    # if only writing vocab to file
+    if args.vocab_only:
+
+        fname_model = f"{dir_model}/consolidated.00.pth"
+        fname_out = f"{dir_model}/ggml-vocab.bin"
+
+        print(f"Extracting only the vocab from '{fname_model}'\n")
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            fout.write(struct.pack("i", hparams["vocab_size"]))
+            write_tokens(fout, tokenizer)
+
+        del model
+
+        print(f"Done. Output file: {fname_out}\n")
+
+        return
+
     n_parts = get_n_parts(hparams["dim"])
 
     for p in range(n_parts):
@@ -151,6 +172,7 @@ def main():
             process_and_write_variables(fout, model, ftype)
 
         del model
+
         print(f"Done. Output file: {fname_out}, (part {p})\n")
 
 if __name__ == "__main__":