@@ -1,14 +1,10 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal, Any
+from typing import Literal
 
 import os
 import json
-import requests
-import threading
-import logging
-from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -111,15 +107,9 @@ class SafetensorRemote:
             print(data)
     """
 
-    logger = logging.getLogger("safetensor_remote")
-
    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes
 
-    # start using multithread download for files larger than 100MB
-    MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024
-    MULTITHREAD_COUNT = 8  # number of threads
-
    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
        """
@@ -221,153 +211,47 @@ def get_metadata(cls, url: str) -> tuple[dict, int]:
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
-    @classmethod
-    def _get_request_headers(cls) -> dict[str, str]:
-        """Prepare common headers for requests."""
-        headers = {"User-Agent": "convert_hf_to_gguf"}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        return headers
-
    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
-        Get raw byte data from a remote file by range using single or multi-threaded download.
-
-        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
-        If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads.
-        Otherwise, it uses a single request.
+        Get raw byte data from a remote file by range.
+        If size is not specified, it will read the entire file.
        """
+        import requests
+        from urllib.parse import urlparse
+
        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")
 
-        common_headers = cls._get_request_headers()
-
-        # --- Multithreading Path ---
-        if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1:
-            cls.logger.info(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB")
-            num_threads = cls.MULTITHREAD_COUNT
-            results: list[Any] = [None] * num_threads  # Store results or exceptions
-            threads: list[threading.Thread] = []
-
-            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
-                """Worker function for thread."""
-                thread_headers = headers.copy()
-                # Range header is inclusive end byte
-                range_end = chunk_start + chunk_size - 1
-                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
-                try:
-                    # Using stream=False should make requests wait for content download
-                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)  # Added timeout
-                    response.raise_for_status()  # Check for HTTP errors
-
-                    content = response.content
-                    if len(content) != chunk_size:
-                        # This is a critical check
-                        raise IOError(
-                            f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
-                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
-                        )
-                    result_list[index] = content
-                except Exception as e:
-                    # Store exception to be raised by the main thread
-                    # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}")  # Optional debug print
-                    result_list[index] = e
-
-            # Calculate chunk sizes and create/start threads
-            base_chunk_size = size // num_threads
-            remainder = size % num_threads
-            current_offset = start
-
-            for i in range(num_threads):
-                chunk_size = base_chunk_size + (1 if i < remainder else 0)
-                if chunk_size == 0:  # Should not happen if size >= threshold but handle defensively
-                    results[i] = b""  # Store empty bytes for this "chunk"
-                    continue
-
-                thread = threading.Thread(
-                    target=download_chunk,
-                    args=(url, current_offset, chunk_size, i, results, common_headers),
-                    daemon=True  # Allow main thread to exit even if daemon threads are stuck (though join prevents this)
-                )
-                threads.append(thread)
-                thread.start()
-                current_offset += chunk_size  # Move offset for the next chunk
-
-            # Wait for all threads to complete
-            for i, thread in enumerate(threads):
-                thread.join()  # Wait indefinitely for each thread
-
-            # Check results for errors and concatenate chunks
-            final_data_parts = []
-            for i in range(num_threads):
-                result = results[i]
-                if isinstance(result, Exception):
-                    # Raise the first exception encountered
-                    raise result
-                elif result is None:
-                    # This indicates a thread finished without setting its result or exception (unexpected)
-                    # Check if it was supposed to download anything
-                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
-                    if expected_chunk_size > 0:
-                        raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.")
-                    else:
-                        final_data_parts.append(b"")  # Append empty bytes for zero-size chunk
-                else:
-                    final_data_parts.append(result)
-
-            # Combine the byte chunks
-            final_data = b"".join(final_data_parts)
-
-            # Final validation: Does the combined size match the requested size?
-            if len(final_data) != size:
-                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}")
-
-            return final_data
-
-        # --- Single-threaded Path ---
-        else:
-            # print(f"Using single thread for size {size}")  # Optional debug print
-            headers = common_headers.copy()
-            if size > -1:
-                # Range header uses inclusive end byte
-                range_end = start + size - 1
-                headers["Range"] = f"bytes={start}-{range_end}"
-            elif start > 0:
-                # Request from start offset to the end of the file
-                headers["Range"] = f"bytes={start}-"
-            # If start=0 and size=-1, no Range header is needed (get full file)
-
-            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)  # Added timeout
-            response.raise_for_status()
-            content = response.content
-
-            # Validate downloaded size if a specific size was requested
-            if size > -1 and len(content) != size:
-                # Check status code - 206 Partial Content is expected for successful range requests
-                status_code = response.status_code
-                content_range = response.headers.get('Content-Range')
-                raise IOError(
-                    f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
-                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
-                )
-
-            return content
+        headers = {}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        if size > -1:
+            headers["Range"] = f"bytes={start}-{start + size}"
+        response = requests.get(url, allow_redirects=True, headers=headers)
+        response.raise_for_status()
+
+        # Get raw byte data
+        return response.content[:size]
 
    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        """
+        import requests
+        from urllib.parse import urlparse
+
        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")
 
        try:
-            headers = cls._get_request_headers()
-            headers["Range"] = "bytes=0-0"  # Request a small range to check existence
+            headers = {"Range": "bytes=0-0"}
+            if os.environ.get("HF_TOKEN"):
+                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
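One behavioral note on the restored single-request path: the end offset of an HTTP `Range` header is inclusive (RFC 9110), so `bytes={start}-{start + size}` asks the server for `size + 1` bytes, and the trailing `response.content[:size]` slice is what trims the result back to `size` (it also means the `size == -1` full-file case returns `content[:-1]`, everything but the last byte). A hypothetical helper, shown only to illustrate the inclusive arithmetic and not part of this commit:

```python
def exact_range_header(start: int, size: int) -> dict[str, str]:
    # RFC 9110: both Range offsets are inclusive, so `size` bytes starting
    # at `start` span the offsets start .. start + size - 1.
    return {"Range": f"bytes={start}-{start + size - 1}"}

assert exact_range_header(0, 8) == {"Range": "bytes=0-7"}     # first 8 bytes
assert exact_range_header(16, 4) == {"Range": "bytes=16-19"}  # bytes 16..19
```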