diff options
author | ubik2 <ubik2@users.noreply.github.com> | 2023-05-08 04:54:26 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-08 13:54:26 +0200 |
commit | 95078cc554fe03d4512363c7e4dec963f0047c72 (patch) | |
tree | 2568ac59f2a987f5a2dce0ce96c0c014d474c255 /convert.py | |
parent | 1f48b0abcfbd6cc99571e42348e0ec97e4be8b93 (diff) |
convert: add ability to convert safetensors files (#1276)
* when loading a safetensors file, ignore the metadata header
* check for safetensors files first, and only use PyTorch versions when safetensors aren't available
Diffstat (limited to 'convert.py')
-rw-r--r-- | convert.py | 10 |
1 files changed, 7 insertions, 3 deletions
@@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) description = f'safetensors begin={begin} end={end} type={data_type} path={path}' return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items()} + model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) @@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] - files = [file for glob in globs for file in path.glob(glob)] + # Check if it's a set of safetensors files first + files = list(path.glob("model-00001-of-*.safetensors")) + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] + files = [file for glob in globs for file in path.glob(glob)] if not files: # Try GGML too, but with lower priority, since if both a non-GGML # model and a GGML model exist in the same directory, we assume the |