Skip to content

Recover from aborted or failed model downloads #358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 commit merged on Apr 21, 2024
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import urllib.request
from pathlib import Path
from typing import Optional, Sequence
Expand All @@ -19,26 +20,22 @@


def _download_hf_snapshot(
model_config: ModelConfig, models_dir: Path, hf_token: Optional[str]
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
):
model_dir = models_dir / model_config.name
os.makedirs(model_dir, exist_ok=True)

from huggingface_hub import snapshot_download

# Download and store the HF model artifacts.
print(f"Downloading {model_config.name} from HuggingFace...")
try:
snapshot_download(
model_config.distribution_path,
local_dir=model_dir,
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*",
)
except HTTPError as e:
if e.response.status_code == 401:
os.rmdir(model_dir)
raise RuntimeError(
"Access denied. Run huggingface-cli login to authenticate."
)
Expand All @@ -48,20 +45,16 @@ def _download_hf_snapshot(

# Convert the model to the torchchat format.
print(f"Converting {model_config.name} to torchchat format...")
convert_hf_checkpoint(model_dir=model_dir, model_name=model_config.name, remove_bin_files=True)
convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)


def _download_direct(
model_config: ModelConfig,
urls: Sequence[str],
models_dir: Path,
artifact_dir: Path,
):
model_dir = models_dir / model_config.name
os.makedirs(model_dir, exist_ok=True)

for url in urls:
for url in model_config.distribution_path:
filename = url.split("/")[-1]
local_path = model_dir / filename
local_path = artifact_dir / filename
print(f"Downloading {url}...")
urllib.request.urlretrieve(url, str(local_path.absolute()))

Expand All @@ -70,18 +63,38 @@ def download_and_convert(
model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
model_config = resolve_model_config(model)
model_dir = models_dir / model_config.name

if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, models_dir, hf_token)
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, model_config.distribution_path, models_dir)
else:
raise RuntimeError(
f"Unknown distribution channel {model_config.distribution_channel}."
)
# Download into a temporary directory. We'll move it to the final location once
# the download and conversion are complete. This allows recovery in the event
# that the download or conversion fails unexpectedly.
temp_dir = models_dir / "downloads" / model_config.name
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
os.makedirs(temp_dir, exist_ok=True)

try:
if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, temp_dir, hf_token)
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, temp_dir)
else:
raise RuntimeError(
f"Unknown distribution channel {model_config.distribution_channel}."
)

# Move from the temporary directory to the intended location,
# overwriting if necessary.
if os.path.isdir(model_dir):
shutil.rmtree(model_dir)
os.rename(temp_dir, model_dir)

finally:
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)


def is_model_downloaded(model: str, models_dir: Path) -> bool:
Expand Down