Diffstat (limited to 'convert.py')
-rw-r--r--  convert.py | 193
1 file changed, 128 insertions(+), 65 deletions(-)
diff --git a/convert.py b/convert.py
index 7a2705e..f3bf179 100644
--- a/convert.py
+++ b/convert.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
import argparse
import concurrent.futures
import copy
@@ -132,7 +133,7 @@ TENSORS_SET = set(TENSORS_LIST)
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
- for n_mult in range(256, 1, -1):
+ for n_mult in range(8192, 1, -1):
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
if calc_ff == n_ff:
return n_mult
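
A minimal sketch (illustration only, not part of the patch) of why the search cap moves from 256 to 8192: LLaMA v2 70B reports n_embd = 8192 and n_ff = 28672 in its HF config, and the matching multiplier sits far above the old 256 ceiling. The trailing raise mirrors the file's existing fallback.

    def find_n_mult(n_ff: int, n_embd: int) -> int:
        for n_mult in range(8192, 1, -1):
            calc_ff = (((8 * n_embd) // 3 + n_mult - 1) // n_mult) * n_mult
            if calc_ff == n_ff:
                return n_mult
        raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

    # (8*8192)//3 == 21845, and ceil(21845 / 7168) * 7168 == 28672; no divisor
    # of 28672 lies in (7168, 8192], so the scan's first hit is 7168 --
    # unreachable under the old cap of 256.
    print(find_n_mult(28672, 8192))  # 7168
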
@@ -140,11 +141,12 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
@dataclass
class Params:
- n_vocab: int
- n_embd: int
- n_mult: int
- n_head: int
- n_layer: int
+ n_vocab: int
+ n_embd: int
+ n_mult: int
+ n_head: int
+ n_layer: int
+ n_kv_head: Optional[int] # This parameter is only used for Llama 2
@staticmethod
def guessed(model: 'LazyModel') -> 'Params':
@@ -166,11 +168,12 @@ class Params:
n_head=n_embd // 128 # guessed
return Params(
- n_vocab=n_vocab,
- n_embd=n_embd,
- n_mult=256,
- n_head=n_head,
- n_layer=n_layer,
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = 256,
+ n_head = n_head,
+ n_layer = n_layer,
+ n_kv_head = None,
)
@staticmethod
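
The n_head = n_embd // 128 guess above assumes a head dimension of 128, which holds across the published LLaMA sizes; a quick arithmetic check (mine, not part of the patch):

    # dim // 128 recovers the documented head counts for 7B/13B/30B/70B.
    for n_embd, n_head in [(4096, 32), (5120, 40), (6656, 52), (8192, 64)]:
        assert n_embd // 128 == n_head
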
@@ -178,28 +181,56 @@ class Params:
config = json.load(open(config_path))
n_vocab = config["vocab_size"];
- n_embd = config["hidden_size"];
- n_head = config["num_attention_heads"];
+ n_embd = config["hidden_size"];
+ n_head = config["num_attention_heads"];
n_layer = config["num_hidden_layers"];
- n_ff = config["intermediate_size"];
+ n_ff = config["intermediate_size"];
+ n_kv_head = config.get("num_key_value_heads")
n_mult = find_n_mult(n_ff, n_embd);
return Params(
- n_vocab=n_vocab,
- n_embd=n_embd,
- n_mult=n_mult,
- n_head=n_head,
- n_layer=n_layer,
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = n_mult,
+ n_head = n_head,
+ n_layer = n_layer,
+ n_kv_head = n_kv_head,
+ )
+
+ # LLaMA v2 70B params.json
+ # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
+ @staticmethod
+ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+ config = json.load(open(config_path))
+
+ n_vocab = config["vocab_size"];
+ n_embd = config["dim"];
+ n_head = config["n_heads"];
+ n_layer = config["n_layers"];
+ n_mult = config["multiple_of"];
+
+ if n_vocab == -1:
+ n_vocab = model["tok_embeddings.weight"].shape[0]
+
+ return Params(
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = n_mult,
+ n_head = n_head,
+ n_layer = n_layer,
+ n_kv_head = None,
)
@staticmethod
def load(model_plus: 'ModelPlus') -> 'Params':
+ hf_config_path = model_plus.paths[0].parent / "config.json"
orig_config_path = model_plus.paths[0].parent / "params.json"
- hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
- if hf_transformer_config_path.exists():
- params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+ if hf_config_path.exists():
+ params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+ elif orig_config_path.exists():
+ params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
else:
params = Params.guessed(model_plus.model)
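
A standalone sketch (helper name is mine) of the lookup order Params.load now implements: an HF-style config.json takes precedence, then a Meta-style params.json, and only then the shape-based guess.

    from pathlib import Path

    def config_source(model_dir: Path) -> str:
        if (model_dir / "config.json").exists():
            return "hf"        # Params.loadHFTransformerJson
        if (model_dir / "params.json").exists():
            return "original"  # Params.loadOriginalParamsJson
        return "guessed"       # Params.guessed

For the 70B params.json quoted above, vocab_size is -1, which is the loader's cue to take n_vocab from tok_embeddings.weight.shape[0] instead.
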
@@ -208,14 +239,21 @@ class Params:
class SentencePieceVocab:
- def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
- self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+ self.vocabtype = vocabtype
+ if self.vocabtype == "bpe":
+ self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
+ else:
+ self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens))
else:
added_tokens = {}
- vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+ if self.vocabtype == "bpe":
+ vocab_size: int = len(self.sentencepiece_tokenizer)
+ else:
+ vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids:
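
The expected_ids/actual_ids comparison above enforces that added_tokens.json occupies a contiguous id block starting exactly at the base vocab size; a toy illustration (values made up):

    vocab_size = 32000
    added_tokens = {"<pad>": 32000, "<extra_0>": 32001}
    expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
    assert sorted(added_tokens.values()) == expected_ids  # gaps or overlaps fail
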
@@ -229,22 +267,32 @@ class SentencePieceVocab:
def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
tokenizer = self.sentencepiece_tokenizer
- for i in range(tokenizer.vocab_size()):
+ if self.vocabtype == "bpe":
+ from transformers.models.gpt2 import tokenization_gpt2
+ byte_encoder = tokenization_gpt2.bytes_to_unicode()
+ byte_decoder = {v: k for k, v in byte_encoder.items()}
+ for i, item in enumerate(tokenizer):
text: bytes
- if tokenizer.is_unknown(i):
- text = " \u2047 ".encode("utf-8")
- elif tokenizer.is_control(i):
- text = b""
- elif tokenizer.is_byte(i):
- piece = tokenizer.id_to_piece(i)
- if len(piece) != 6:
- raise Exception(f"Invalid token: {piece}")
- byte_value = int(piece[3:-1], 16)
- text = struct.pack("B", byte_value)
- else:
- text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
- score: float = tokenizer.get_score(i)
+ text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+ score: float = -i
yield text, score
+ else:
+ for i in range(tokenizer.vocab_size()):
+ text: bytes
+ if tokenizer.is_unknown(i):
+ text = " \u2047 ".encode("utf-8")
+ elif tokenizer.is_control(i):
+ text = b""
+ elif tokenizer.is_byte(i):
+ piece = tokenizer.id_to_piece(i)
+ if len(piece) != 6:
+ raise Exception(f"Invalid token: {piece}")
+ byte_value = int(piece[3:-1], 16)
+ text = struct.pack("B", byte_value)
+ else:
+ text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+ score: float = tokenizer.get_score(i)
+ yield text, score
def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
for text in self.added_tokens_list:
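
A minimal sketch (not part of the patch; requires transformers, as the BPE branch itself does) of the round trip that branch relies on: vocab.json stores each token as a string built with GPT-2's reversible byte-to-unicode map, so inverting the map recovers the token's raw bytes.

    from transformers.models.gpt2 import tokenization_gpt2

    byte_encoder = tokenization_gpt2.bytes_to_unicode()     # Dict[int, str]
    byte_decoder = {v: k for k, v in byte_encoder.items()}  # Dict[str, int]

    token = "\u0120hello"  # "Ġhello": "Ġ" is how the map encodes byte 0x20 (space)
    raw = bytes(byte_decoder[ch] for ch in token)
    assert raw == b" hello"

vocab.json carries no scores, which is why the BPE branch substitutes -i: descending rank order stands in for real log-probabilities.
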
@@ -274,10 +322,12 @@ class GGMLVocab:
Vocab = Union[SentencePieceVocab, GGMLVocab]
-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+ if n_kv_head is not None and n_head != n_kv_head:
+ n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
- .swapaxes(1, 2)
- .reshape(weights.shape))
+ .swapaxes(1, 2)
+ .reshape(weights.shape))
def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
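
A toy NumPy illustration (dimensions made up) of the grouped-query path through permute(): when n_kv_head is set and differs from n_head, the reshape operates per KV head, reordering rows within each KV head's block just as the single-head case does per query head.

    import numpy as np

    def permute(weights, n_head, n_kv_head=None):
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

    w = np.arange(8 * 4).reshape(8, 4)  # 2 KV heads x head_dim 4 rows, 4 cols
    out = permute(w, n_head=4, n_kv_head=2)
    # Rows emerge as 0,2,1,3 then 4,6,5,7: the two halves of each KV head
    # are interleaved within that head's block.
    assert (out == w[[0, 2, 1, 3, 4, 6, 5, 7]]).all()
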
@@ -325,7 +375,7 @@ class Tensor(metaclass=ABCMeta):
@abstractmethod
def astype(self, data_type: DataType) -> 'Tensor': ...
@abstractmethod
- def permute(self, n_head: int) -> 'Tensor': ...
+ def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
@abstractmethod
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
@abstractmethod
@@ -363,8 +413,8 @@ class UnquantizedTensor(Tensor):
r = self.ndarray.shape[0] // 3
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
- def permute(self, n_head: int) -> 'UnquantizedTensor':
- return UnquantizedTensor(permute(self.ndarray, n_head))
+ def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
+ return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -412,26 +462,34 @@ class GGMLQuantizedTensor(Tensor):
def to_ggml(self) -> 'GGMLQuantizedTensor':
return self
- def permute(self, n_head: int) -> 'GGMLQuantizedTensor':
- return GGMLQuantizedTensor(permute(self.ndarray, n_head), self.shape, self.data_type)
+ def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor':
+ return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head), self.shape, self.data_type)
+ def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
+ r = self.ndarray.shape[0] // 3
+ return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))
+
+ def part(self, n_part: int) -> 'UnquantizedTensor':
+ r = self.ndarray.shape[0] // 3
+ return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor]
class DeferredPermutedTensor(Tensor):
- def __init__(self, base: Tensor, n_head: int) -> None:
+ def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
self.base = base
self.n_head = n_head
+ self.n_kv_head = n_kv_head
self.data_type = self.base.data_type
def astype(self, data_type: DataType) -> Tensor:
- return self.base.astype(data_type).permute(self.n_head)
+ return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
def to_ggml(self) -> GGMLCompatibleTensor:
- return self.base.to_ggml().permute(self.n_head)
+ return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
- def permute(self, n_head: int) -> Tensor:
+ def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
raise Exception("shouldn't permute twice")
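
The new permute_part()/part() methods on GGMLQuantizedTensor mirror the UnquantizedTensor ones: they slice one third out of a packed QKV ("W_pack") matrix. A toy illustration (shapes made up):

    import numpy as np

    w_pack = np.arange(6 * 4).reshape(6, 4)  # three stacked projections, r = 2 rows each
    r = w_pack.shape[0] // 3
    q, k, v = (w_pack[r * p : r * p + r, ...] for p in range(3))
    assert q.shape == k.shape == v.shape == (2, 4)
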
@@ -523,8 +581,8 @@ class GPTQForLLaMaQuantizedTensor(Tensor):
ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
return ret
- def permute(self, n_head: int) -> Tensor:
- return DeferredPermutedTensor(self, n_head)
+ def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+ return DeferredPermutedTensor(self, n_head, n_kv_head)
def to_ggml(self) -> GGMLQuantizedTensor:
# The output format looks like this:
@@ -655,10 +713,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
return ModelPlus(model, paths, format, vocab)
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
def load() -> Tensor:
- return lazy_tensor.load().permute(n_head)
- return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+ return lazy_tensor.load().permute(n_head, n_kv_head)
+ return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
def load() -> Tensor:
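
permute_lazy() and permute_part_lazy() share one pattern, sketched standalone here (names are mine, reusing the permute() sketch above): wrap the work in a closure so nothing is read or permuted until load() is actually called.

    def permute_later(load_weights, n_head, n_kv_head=None):
        # Cheap to construct: the weights are neither read nor permuted
        # until the returned thunk runs, mirroring LazyTensor.load().
        def load():
            return permute(load_weights(), n_head, n_kv_head)
        return load
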
@@ -683,7 +741,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
- out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+ out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
@@ -1035,8 +1093,7 @@ class OutputFile:
@staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
of = OutputFile(fname_out)
- params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
- n_head=1, n_layer=0)
+ params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0, n_kv_head=None)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)
@@ -1171,14 +1228,18 @@ def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}
-def load_vocab(path: Path) -> SentencePieceVocab:
+def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
+ print(f"vocabtype: {vocabtype}")
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that.
if path.is_dir():
- path2 = path / "tokenizer.model"
+ vocab_file = "tokenizer.model"
+ if vocabtype == 'bpe':
+ vocab_file = "vocab.json"
+ path2 = path / vocab_file
# Use `.parent` instead of /.. to handle the symlink case better.
- path3 = path.parent / "tokenizer.model"
+ path3 = path.parent / vocab_file
if path2.exists():
path = path2
elif path3.exists():
@@ -1189,7 +1250,8 @@ def load_vocab(path: Path) -> SentencePieceVocab:
"if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
- return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+ return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+ vocabtype)
def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
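
A small sketch (helper name is mine) of the filename choice the load_vocab changes introduce: SentencePiece checkpoints ship tokenizer.model, BPE vocabularies ship vocab.json, and the directory and its parent are searched either way.

    from typing import Optional

    def vocab_filename(vocabtype: Optional[str]) -> str:
        return "vocab.json" if vocabtype == "bpe" else "tokenizer.model"
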
@@ -1227,6 +1289,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path,
help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
+ parser.add_argument("--vocabtype", default='spm', choices=["spm", "bpe"], help="vocab format (default: spm)")
args = parser.parse_args(args_in)
vocab: Vocab
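
Example usage of the new flag (model path is a placeholder); the default stays 'spm', so existing SentencePiece conversions need no change:

    #   python convert.py models/my-bpe-model --vocabtype bpe
    # main() also accepts an explicit argv list, so the same run can be
    # driven programmatically:
    main(["models/my-bpe-model", "--vocabtype", "bpe"])
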
@@ -1234,7 +1297,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
model_plus = lazy_load_file(args.model)
do_dump_model(model_plus)
elif args.vocab_only:
- vocab = load_vocab(args.vocab_dir or args.model)
+ vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
assert args.outfile, "need --outfile if using --vocab-only"
outfile = args.outfile
OutputFile.write_vocab_only(outfile, vocab)
@@ -1248,7 +1311,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab = model_plus.vocab
else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
- vocab = load_vocab(vocab_dir)
+ vocab = load_vocab(vocab_dir, args.vocabtype)
params = Params.load(model_plus)
model = model_plus.model
model = do_necessary_conversions(model, params)