4
4
# This source code is licensed under the license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
import os
7
+ import shutil
7
8
import urllib .request
8
9
from pathlib import Path
9
10
from typing import Optional , Sequence
19
20
20
21
21
22
def _download_hf_snapshot (
22
- model_config : ModelConfig , models_dir : Path , hf_token : Optional [str ]
23
+ model_config : ModelConfig , artifact_dir : Path , hf_token : Optional [str ]
23
24
):
24
- model_dir = models_dir / model_config .name
25
- os .makedirs (model_dir , exist_ok = True )
26
-
27
25
from huggingface_hub import snapshot_download
28
26
29
27
# Download and store the HF model artifacts.
30
28
print (f"Downloading { model_config .name } from HuggingFace..." )
31
29
try :
32
30
snapshot_download (
33
31
model_config .distribution_path ,
34
- local_dir = model_dir ,
32
+ local_dir = artifact_dir ,
35
33
local_dir_use_symlinks = False ,
36
34
token = hf_token ,
37
35
ignore_patterns = "*safetensors*" ,
38
36
)
39
37
except HTTPError as e :
40
38
if e .response .status_code == 401 :
41
- os .rmdir (model_dir )
42
39
raise RuntimeError (
43
40
"Access denied. Run huggingface-cli login to authenticate."
44
41
)
@@ -48,20 +45,16 @@ def _download_hf_snapshot(
48
45
49
46
# Convert the model to the torchchat format.
50
47
print (f"Converting { model_config .name } to torchchat format..." )
51
- convert_hf_checkpoint (model_dir = model_dir , model_name = model_config .name , remove_bin_files = True )
48
+ convert_hf_checkpoint (model_dir = artifact_dir , model_name = model_config .name , remove_bin_files = True )
52
49
53
50
54
51
def _download_direct(
    model_config: ModelConfig,
    artifact_dir: Path,
):
    """Download each URL in ``model_config.distribution_path`` into ``artifact_dir``.

    Every file is stored under the last segment of its URL path.

    Args:
        model_config: Model description; only ``distribution_path`` is read
            here. It is iterated as a collection of URL strings; a single
            URL string is also accepted.
        artifact_dir: Directory the files are written to. It is not created
            here, so it must already exist.

    Raises:
        RuntimeError: If no file name can be derived from a URL (for example
            when the URL path ends with ``/``).
    """
    # Local import keeps this block self-contained; the module itself only
    # imports urllib.request at the top.
    from urllib.parse import urlsplit

    urls = model_config.distribution_path
    if isinstance(urls, str):
        # Iterating a bare string would yield single characters, not URLs.
        urls = [urls]

    for url in urls:
        # Take the name from the path component only, so a query string or
        # fragment (e.g. "?download=true") does not leak into the file name.
        filename = urlsplit(url).path.split("/")[-1]
        if not filename:
            # Without a name the target would be artifact_dir itself and the
            # write would fail only after the download had started.
            raise RuntimeError(f"Unable to determine a file name from URL {url}.")

        local_path = artifact_dir / filename
        print(f"Downloading {url}...")
        urllib.request.urlretrieve(url, str(local_path.absolute()))
67
60
def download_and_convert(
    model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
    """Download a model and place its artifacts in ``models_dir``.

    The artifacts are first fetched (and, for HuggingFace snapshots,
    converted) inside ``models_dir / "downloads" / <name>``. Only after that
    succeeds is the directory moved to ``models_dir / <name>``, replacing any
    previous copy, so a failed run does not leave a partial model at the
    final location.

    Args:
        model: Model identifier passed to ``resolve_model_config``.
        models_dir: Root directory that holds downloaded models.
        hf_token: Optional token forwarded to the HuggingFace download.

    Raises:
        RuntimeError: If the model's distribution channel is not one of the
            channels handled here.
    """
    model_config = resolve_model_config(model)
    model_dir = models_dir / model_config.name

    # Download into a temporary directory. We'll move to the final location once
    # the download and conversion is complete. This allows recovery in the event
    # that the download or conversion fails unexpectedly.
    temp_dir = models_dir / "downloads" / model_config.name
    if os.path.isdir(temp_dir):
        # Leftovers from an earlier interrupted run; start from scratch.
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    try:
        if (
            model_config.distribution_channel
            == ModelDistributionChannel.HuggingFaceSnapshot
        ):
            _download_hf_snapshot(model_config, temp_dir, hf_token)
        elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
            _download_direct(model_config, temp_dir)
        else:
            raise RuntimeError(
                f"Unknown distribution channel {model_config.distribution_channel}."
            )

        # Move from the temporary directory to the intended location,
        # overwriting if necessary.
        if os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
        # The destination's parent may not exist yet (e.g. if the model name
        # contains a path separator -- TODO confirm against the known model
        # names); a plain rename would fail in that case.
        os.makedirs(model_dir.parent, exist_ok=True)
        # shutil.move renames when possible and falls back to copy + delete
        # when source and destination are on different filesystems.
        shutil.move(str(temp_dir), str(model_dir))

    finally:
        # After a successful move temp_dir no longer exists; after a failure
        # this removes the partial download.
        if os.path.isdir(temp_dir):
            shutil.rmtree(temp_dir)
85
98
86
99
87
100
def is_model_downloaded (model : str , models_dir : Path ) -> bool :
0 commit comments