aboutsummaryrefslogtreecommitdiff
path: root/convert.py
diff options
context:
space:
mode:
Diffstat (limited to 'convert.py')
-rw-r--r--convert.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/convert.py b/convert.py
index 126beaa..8f4f039 100644
--- a/convert.py
+++ b/convert.py
@@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
return LazyTensor(load, shape, data_type, description)
- model = {name: convert(info) for (name, info) in header.items()}
+ model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
@@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus:
'''Load a model of any supported format.'''
# Be extra-friendly and accept either a file or a directory:
if path.is_dir():
- globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
- files = [file for glob in globs for file in path.glob(glob)]
+ # Check if it's a set of safetensors files first
+ files = list(path.glob("model-00001-of-*.safetensors"))
+ if not files:
+ # Try the PyTorch patterns too, with lower priority
+ globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+ files = [file for glob in globs for file in path.glob(glob)]
if not files:
# Try GGML too, but with lower priority, since if both a non-GGML
# model and a GGML model exist in the same directory, we assume the