Skip to content

Commit 8fc747a

Browse files
GregoryComer malfet
authored and committed
Recover from aborted or failed model downloads (#358)
1 parent 1e2aa07 commit 8fc747a

File tree

1 file changed

+38
-25
lines changed

1 file changed

+38
-25
lines changed

download.py

Lines changed: 38 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66
import os
7+
import shutil
78
import urllib.request
89
from pathlib import Path
910
from typing import Optional, Sequence
@@ -19,26 +20,22 @@
1920

2021

2122
def _download_hf_snapshot(
22-
model_config: ModelConfig, models_dir: Path, hf_token: Optional[str]
23+
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
2324
):
24-
model_dir = models_dir / model_config.name
25-
os.makedirs(model_dir, exist_ok=True)
26-
2725
from huggingface_hub import snapshot_download
2826

2927
# Download and store the HF model artifacts.
3028
print(f"Downloading {model_config.name} from HuggingFace...")
3129
try:
3230
snapshot_download(
3331
model_config.distribution_path,
34-
local_dir=model_dir,
32+
local_dir=artifact_dir,
3533
local_dir_use_symlinks=False,
3634
token=hf_token,
3735
ignore_patterns="*safetensors*",
3836
)
3937
except HTTPError as e:
4038
if e.response.status_code == 401:
41-
os.rmdir(model_dir)
4239
raise RuntimeError(
4340
"Access denied. Run huggingface-cli login to authenticate."
4441
)
@@ -48,20 +45,16 @@ def _download_hf_snapshot(
4845

4946
# Convert the model to the torchchat format.
5047
print(f"Converting {model_config.name} to torchchat format...")
51-
convert_hf_checkpoint(model_dir=model_dir, model_name=model_config.name, remove_bin_files=True)
48+
convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)
5249

5350

5451
def _download_direct(
5552
model_config: ModelConfig,
56-
urls: Sequence[str],
57-
models_dir: Path,
53+
artifact_dir: Path,
5854
):
59-
model_dir = models_dir / model_config.name
60-
os.makedirs(model_dir, exist_ok=True)
61-
62-
for url in urls:
55+
for url in model_config.distribution_path:
6356
filename = url.split("/")[-1]
64-
local_path = model_dir / filename
57+
local_path = artifact_dir / filename
6558
print(f"Downloading {url}...")
6659
urllib.request.urlretrieve(url, str(local_path.absolute()))
6760

@@ -70,18 +63,38 @@ def download_and_convert(
7063
model: str, models_dir: Path, hf_token: Optional[str] = None
7164
) -> None:
7265
model_config = resolve_model_config(model)
66+
model_dir = models_dir / model_config.name
7367

74-
if (
75-
model_config.distribution_channel
76-
== ModelDistributionChannel.HuggingFaceSnapshot
77-
):
78-
_download_hf_snapshot(model_config, models_dir, hf_token)
79-
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
80-
_download_direct(model_config, model_config.distribution_path, models_dir)
81-
else:
82-
raise RuntimeError(
83-
f"Unknown distribution channel {model_config.distribution_channel}."
84-
)
68+
# Download into a temporary directory. We'll move to the final location once
69+
# the download and conversion is complete. This allows recovery in the event
70+
# that the download or conversion fails unexpectedly.
71+
temp_dir = models_dir / "downloads" / model_config.name
72+
if os.path.isdir(temp_dir):
73+
shutil.rmtree(temp_dir)
74+
os.makedirs(temp_dir, exist_ok=True)
75+
76+
try:
77+
if (
78+
model_config.distribution_channel
79+
== ModelDistributionChannel.HuggingFaceSnapshot
80+
):
81+
_download_hf_snapshot(model_config, temp_dir, hf_token)
82+
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
83+
_download_direct(model_config, temp_dir)
84+
else:
85+
raise RuntimeError(
86+
f"Unknown distribution channel {model_config.distribution_channel}."
87+
)
88+
89+
# Move from the temporary directory to the intended location,
90+
# overwriting if necessary.
91+
if os.path.isdir(model_dir):
92+
shutil.rmtree(model_dir)
93+
os.rename(temp_dir, model_dir)
94+
95+
finally:
96+
if os.path.isdir(temp_dir):
97+
shutil.rmtree(temp_dir)
8598

8699

87100
def is_model_downloaded(model: str, models_dir: Path) -> bool:

0 commit comments

Comments
 (0)