Skip to content

Commit 2507c8e

Browse files
committed
gguf util : add SafetensorRemote
1 parent 656babd commit 2507c8e

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

gguf-py/gguf/utility.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import Literal
44

5+
import json
6+
57

68
def fill_templated_filename(filename: str, output_type: str | None) -> str:
79
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -67,3 +69,172 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
6769
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
6870

6971
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
72+
73+
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(json.dumps(tensors, indent=2))

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            dtype, shape, offset_start, size, remote_safetensor_url = meta
            # read the tensor data
            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
        """
        Get list of tensors from a Hugging Face model repository.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)

        Raises ValueError if the repository contains no safetensor files.
        """
        # case 1: model has only one single model.safetensor file
        is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
        if is_single_file:
            url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
            tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
            for key, val in cls.get_list_tensors(url).items():
                tensors[key] = (*val, url)  # populate the url
            return tensors

        # case 2: model has multiple files, listed by model.safetensors.index.json
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        is_multiple_files = cls.check_file_exist(index_url)
        if is_multiple_files:
            # read the index file
            index_data = cls.get_data_by_range(index_url, 0)
            index_str = index_data.decode('utf-8')
            index_json = json.loads(index_str)
            # validate with an explicit raise (assert is stripped under python -O)
            if index_json.get("weight_map") is None:
                raise ValueError("weight_map not found in index file")
            weight_map = index_json["weight_map"]
            # get the list of files
            all_files = sorted(set(weight_map.values()))  # make sure we load shard files in order
            # get the list of tensors
            tensors = {}
            for file in all_files:
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                for key, val in cls.get_list_tensors(url).items():
                    tensors[key] = (*val, url)  # populate the url
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
        """
        Get list of tensors from a remote safetensor file.

        Returns a dictionary of tensor names and their metadata.
        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)

        Raises ValueError when a tensor entry is malformed or missing keys.
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, tuple[str, list[int], int, int]] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                # reserved key holding file-level metadata, not a tensor
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                offset_start_relative, offset_end_relative = meta["data_offsets"]
                size = offset_end_relative - offset_start_relative
                # data_offsets in the header are relative to the start of the data section
                offset_start = data_start_offset + offset_start_relative
                res[name] = (dtype, shape, offset_start, size)
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)

        Raises ValueError when the fetched prefix does not contain the
        complete metadata blob, or the blob is not valid JSON.
        """
        # Request first 5MB of the file (hopefully enough for metadata)
        read_size = 5 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)
        return cls._parse_metadata(raw_data)

    @classmethod
    def _parse_metadata(cls, raw_data: bytes) -> tuple[dict, int]:
        """
        Parse the safetensors header from the leading bytes of a file.

        raw_data must contain the 8-byte length prefix plus the full JSON
        metadata blob. Returns tuple of (metadata, data_start_offset).

        Raises ValueError on truncated input or invalid JSON.
        """
        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')

        # Calculate the data start offset, rounded up to ALIGNMENT.
        # NOTE(review): safetensors writers normally pad the JSON so that
        # 8 + metadata_length is already aligned, making this a no-op —
        # confirm the rounding matches upstream expectations.
        data_start_offset = 8 + metadata_length
        alignment = cls.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        # Check if we have enough data to read the metadata
        if len(raw_data) < 8 + metadata_length:
            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")

        # Extract metadata bytes and parse as JSON
        metadata_bytes = raw_data[8:8 + metadata_length]
        metadata_str = metadata_bytes.decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
            return metadata, data_start_offset
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read the entire file.

        Raises ValueError for malformed URLs and requests.HTTPError for
        HTTP error status codes.
        """
        # lazy imports keep `requests` optional for users of the rest of this module
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        headers = {}
        if size > -1:
            # HTTP Range end is inclusive, hence the -1
            headers = {"Range": f"bytes={start}-{start + size - 1}"}
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Truncate in case the server ignored the Range header and sent more.
        # (Do NOT slice with size == -1: that would drop the last byte.)
        return response.content[:size] if size > -1 else response.content

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.

        Raises ValueError for malformed URLs.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            # request a single byte so the probe stays cheap
            headers = {"Range": "bytes=0-0"}
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False

0 commit comments

Comments
 (0)