|
2 | 2 |
|
3 | 3 | from typing import Literal
|
4 | 4 |
|
| 5 | +import json |
| 6 | + |
5 | 7 |
|
6 | 8 | def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
7 | 9 | # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
@@ -67,3 +69,172 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
|
67 | 69 | kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
68 | 70 |
|
69 | 71 | return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
| 72 | + |
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(json.dumps(tensors, indent=2))

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            dtype, shape, offset_start, size, remote_safetensor_url = meta
            # read the tensor data
            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
        """
        Get list of tensors from a Hugging Face model repository.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)

        Raises ValueError if the repository has no safetensor files or the
        shard index is malformed.
        """
        # case 1: model has only one single model.safetensor file
        single_file_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
        if cls.check_file_exist(single_file_url):
            tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
            for key, val in cls.get_list_tensors(single_file_url).items():
                tensors[key] = (*val, single_file_url)  # populate the url
            return tensors

        # case 2: model has multiple files listed in model.safetensors.index.json
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        if cls.check_file_exist(index_url):
            # read the index file (size -1 => fetch the whole file)
            index_data = cls.get_data_by_range(index_url, 0)
            index_json = json.loads(index_data.decode('utf-8'))
            # explicit raise instead of assert: asserts are stripped under -O,
            # and the sibling failure path below already uses ValueError
            if index_json.get("weight_map") is None:
                raise ValueError("weight_map not found in index file")
            weight_map = index_json["weight_map"]
            # get the unique list of shard files; sort so shards load in order
            all_files = sorted(set(weight_map.values()))
            # get the list of tensors across all shards
            tensors = {}
            for file in all_files:
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                for key, val in cls.get_list_tensors(url).items():
                    tensors[key] = (*val, url)  # populate the url
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
        """
        Get list of tensors from a remote safetensor file.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, tuple[str, list[int], int, int]] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                # special header entry with file-level info, not a tensor
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                offset_start_relative, offset_end_relative = meta["data_offsets"]
                size = offset_end_relative - offset_start_relative
                # header offsets are relative to the start of the data section;
                # convert to an absolute file offset
                offset_start = data_start_offset + offset_start_relative
                res[name] = (dtype, shape, offset_start, size)
            except KeyError as e:
                # chain the cause so the original KeyError traceback is kept
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)
        """
        # Request first 5MB of the file (hopefully enough for metadata)
        read_size = 5 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)

        # Parse header
        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')

        # Calculate the data start offset, rounded up to the alignment boundary
        data_start_offset = 8 + metadata_length
        # use cls so a subclass overriding ALIGNMENT is honored
        alignment = cls.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        # Check if we have enough data to read the metadata
        if len(raw_data) < 8 + metadata_length:
            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")

        # Extract metadata bytes and parse as JSON
        metadata_bytes = raw_data[8:8 + metadata_length]
        metadata_str = metadata_bytes.decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
            return metadata, data_start_offset
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read the entire file.

        Raises ValueError for a malformed URL; requests.HTTPError on a
        non-success HTTP status.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        headers = {}
        if size > -1:
            # HTTP Range end is inclusive, so this asks for one extra byte;
            # harmless because the result is truncated to `size` below
            headers = {"Range": f"bytes={start}-{start + size}"}
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Only truncate when an explicit size was requested: with size == -1,
        # content[:size] would be content[:-1] and silently drop the last byte
        if size > -1:
            return response.content[:size]
        return response.content

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.

        Raises ValueError for a malformed URL; network failures return False.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            # request a single byte so servers don't prepare the whole body
            headers = {"Range": "bytes=0-0"}
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False
0 commit comments