|
22 | 22 | def _download_hf_snapshot(
|
23 | 23 | model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
|
24 | 24 | ):
|
25 |
| - from huggingface_hub import snapshot_download |
| 25 | + from huggingface_hub import model_info, snapshot_download |
26 | 26 | from requests.exceptions import HTTPError
|
27 | 27 |
|
28 | 28 | # Download and store the HF model artifacts.
|
29 | 29 | print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
|
30 | 30 | try:
|
| 31 | + # Fetch the info about the model's repo |
| 32 | + model_info = model_info(model_config.distribution_path, token=hf_token) |
| 33 | + model_fnames = [f.rfilename for f in model_info.siblings] |
| 34 | + |
| 35 | + # Check which weight formats (pth and/or safetensors) the repo actually provides |
| 36 | + has_pth = any(f.endswith(".pth") for f in model_fnames) |
| 37 | + has_safetensors = any(f.endswith(".safetensors") for f in model_fnames) |
| 38 | + |
| 39 | + # If told to prefer safetensors, ignore pth files |
| 40 | + if model_config.prefer_safetensors: |
| 41 | + if not has_safetensors: |
| 42 | + print( |
| 43 | + f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.", |
| 44 | + file=sys.stderr, |
| 45 | + ) |
| 46 | + exit(1) |
| 47 | + ignore_patterns = "*.pth" |
| 48 | + |
| 49 | + # If the model has both, prefer pth files over safetensors |
| 50 | + elif has_pth and has_safetensors: |
| 51 | + ignore_patterns = "*safetensors*" |
| 52 | + |
| 53 | + # Otherwise, download everything |
| 54 | + else: |
| 55 | + ignore_patterns = None |
| 56 | + |
31 | 57 | snapshot_download(
|
32 | 58 | model_config.distribution_path,
|
33 | 59 | local_dir=artifact_dir,
|
34 | 60 | local_dir_use_symlinks=False,
|
35 | 61 | token=hf_token,
|
36 |
| - ignore_patterns="*safetensors*", |
| 62 | + ignore_patterns=ignore_patterns, |
37 | 63 | )
|
38 | 64 | except HTTPError as e:
|
39 | 65 | if e.response.status_code == 401: # Missing HuggingFace CLI login.
|
|
0 commit comments