about summary refs log tree commit diff
path: root/convert-pth-to-ggml.py
diff options
context:
space:
mode:
Diffstat (limited to 'convert-pth-to-ggml.py')
-rw-r--r--convert-pth-to-ggml.py28
1 file changed, 25 insertions, 3 deletions
diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py
index 108eb1f..c506676 100644
--- a/convert-pth-to-ggml.py
+++ b/convert-pth-to-ggml.py
@@ -10,12 +10,10 @@
# - Name (char[name_length])
# - Data (float[n_dims])
#
-# By default, the bigger matrices are converted to 16-bit floats.
-# This can be disabled by adding the "use-f32" CLI argument.
-#
# At the start of the ggml file we write the model parameters
# and vocabulary.
#
+
import argparse
import os
import sys
@@ -23,6 +21,7 @@ import json
import struct
import numpy as np
import torch
+
from sentencepiece import SentencePieceProcessor
def parse_args():
@@ -30,6 +29,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
parser.add_argument('dir_model', help='directory containing the model checkpoint')
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+ parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file')
return parser.parse_args()
def get_n_parts(dim):
@@ -134,6 +134,27 @@ def main():
ftype_str = ["f32", "f16"]
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+
+ # if only writing vocab to file
+ if args.vocab_only:
+
+ fname_model = f"{dir_model}/consolidated.00.pth"
+ fname_out = f"{dir_model}/ggml-vocab.bin"
+
+ print(f"Extracting only the vocab from '{fname_model}'\n")
+
+ model = torch.load(fname_model, map_location="cpu")
+
+ with open(fname_out, "wb") as fout:
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ write_tokens(fout, tokenizer)
+
+ del model
+
+ print(f"Done. Output file: {fname_out}\n")
+
+ return
+
n_parts = get_n_parts(hparams["dim"])
for p in range(n_parts):
@@ -151,6 +172,7 @@ def main():
process_and_write_variables(fout, model, ftype)
del model
+
print(f"Done. Output file: {fname_out}, (part {p})\n")
if __name__ == "__main__":