aboutsummaryrefslogtreecommitdiff
path: root/download-pth.py
blob: 129532c0c6b403dcb361769720b668af43e7f9d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import sys
from tqdm import tqdm
import requests

# Two positional arguments are required: the models directory and the
# model size. With fewer, show the usage text and exit with status 1.
if not sys.argv[2:]:
    print("Usage: download-pth.py dir-model model-type\n")
    print("  model-type: Available models 7B, 13B, 30B or 65B")
    sys.exit(1)

# Any arguments beyond the second are ignored.
modelsDir, model = sys.argv[1:3]

# Number of consolidated.*.pth weight files published for each model size.
num = {"7B": 1, "13B": 2, "30B": 4, "65B": 8}

# Reject an unknown size before any directory is created or file fetched.
if num.get(model) is None:
    print(f"Error: model {model} is not valid, provide 7B, 13B, 30B or 65B")
    sys.exit(1)

print(f"Downloading model {model}")

# Metadata files first, then one weight shard per partition. The shard index
# is zero-padded to two digits (consolidated.00.pth, consolidated.01.pth, ...);
# the previous "0{i}" form would have produced "consolidated.010.pth" for a
# model with ten or more shards.
files = ["checklist.chk", "params.json"]
files += [f"consolidated.{i:02d}.pth" for i in range(num[model])]

# Per-model files are stored under <modelsDir>/<model>; create it if needed.
resolved_path = os.path.abspath(os.path.join(modelsDir, model))
os.makedirs(resolved_path, exist_ok=True)

def _download(url, dest_path, desc):
    """Stream *url* into *dest_path*, showing a tqdm progress bar labelled *desc*.

    The body is first written to ``dest_path + ".part"`` and only renamed to
    *dest_path* once the whole response has been received. An interrupted or
    failed download therefore never leaves a truncated file at *dest_path*,
    which the existence checks below would otherwise skip on every later run.
    A leftover ``.part`` file is simply overwritten by the next attempt.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status,
            instead of silently saving the error page as the model file.
        requests.Timeout: if connecting or reading stalls past the timeout.
    """
    tmp_path = dest_path + ".part"
    # (connect, read) timeouts in seconds so a stalled server cannot hang us.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        # Content-Length (when sent) lets tqdm show percentage and ETA.
        total = int(response.headers.get("content-length", 0)) or None
        with open(tmp_path, 'wb') as f, tqdm(
            total=total, unit='B', unit_scale=True, miniters=1, desc=desc
        ) as t:
            # 1 MiB chunks: the weight files are large, tiny chunks waste CPU.
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
                    t.update(len(chunk))
    # Atomic on the same filesystem; replaces any stale file at dest_path.
    os.replace(tmp_path, dest_path)


# Per-model files go into <modelsDir>/<model>.
for file in files:
    dest_path = os.path.join(resolved_path, file)

    if os.path.exists(dest_path):
        print(f"Skip file download, it already exists: {file}")
        continue

    _download(f"https://agi.gpt4.org/llama/LLaMA/{model}/{file}", dest_path, file)

# Tokenizer files are shared by all model sizes and live directly in modelsDir.
files2 = ["tokenizer_checklist.chk", "tokenizer.model"]
for file in files2:
    dest_path = os.path.join(modelsDir, file)

    if os.path.exists(dest_path):
        print(f"Skip file download, it already exists: {file}")
        continue

    _download(f"https://agi.gpt4.org/llama/LLaMA/{file}", dest_path, file)