
Commit 3a3682d

convert : allow using lazy remote tensors
It's a bit slow for now since everything is blocking and single-threaded.
1 parent 08ecbbe commit 3a3682d
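
As a rough illustration (not part of this commit), the lazy path only fetches the safetensors header metadata up front and defers the actual weight bytes to per-tensor HTTP range requests. A minimal sketch against the gguf-py API added below; the repo id is a placeholder:

    # Minimal sketch, assuming the gguf-py changes from this commit are available.
    # "some-org/some-model" is a placeholder Hugging Face repo id, not from the commit.
    import gguf

    remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model("some-org/some-model")

    name, t = next(iter(remote_tensors.items()))
    print(name, t.dtype, t.shape)   # header metadata only; no weight data downloaded yet
    blob = t.data()                 # blocking HTTP range request for just this tensor
    print(f"fetched {len(blob)} bytes for {name}")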

2 files changed, +51 -18 lines


convert_hf_to_gguf.py

Lines changed: 29 additions & 8 deletions
@@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")

@@ -83,11 +83,23 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
-        self.lazy = not eager
-        self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
-        self.is_safetensors = len(self.part_names) > 0
-        if not self.is_safetensors:
-            self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+        self.lazy = not eager or (remote_hf_model_id is not None)
+        if remote_hf_model_id is not None:
+            self.is_safetensors = True
+
+            def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
+                logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
+                remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
+                self.tensor_names = set(name for name in remote_tensors.keys())
+                for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
+                    yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
+
+            self.get_tensors = get_remote_tensors
+        else:
+            self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
+            self.is_safetensors = len(self.part_names) > 0
+            if not self.is_safetensors:
+                self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

@@ -5393,6 +5405,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
         lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
         return cast(torch.Tensor, lazy)

+    @classmethod
+    def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
+        dtype = cls._dtype_str_map[remote_tensor.dtype]
+        shape = remote_tensor.shape
+        meta = cls.meta_with_dtype_and_shape(dtype, shape)
+        lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused

@@ -5516,8 +5536,9 @@ def main() -> None:

     if args.remote:
         from huggingface_hub import snapshot_download
+        args.remote = str(dir_model)
         local_dir = snapshot_download(
-            repo_id=str(dir_model),
+            repo_id=args.remote,
             allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")

@@ -5569,7 +5590,7 @@ def main() -> None:
                                      metadata_override=args.metadata, model_name=args.model_name,
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
-                                     small_first_shard=args.no_tensor_first_split)
+                                     small_first_shard=args.no_tensor_first_split, remote_hf_model_id=args.remote or None)

         if args.vocab_only:
             logger.info("Exporting model vocab...")
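
The new LazyTorchTensor.from_remote_tensor defers the download until the lazy tensor is evaluated, at which point the raw bytes are wrapped with torch.frombuffer and reshaped to the recorded shape. A self-contained sketch of just that materialization step, with locally fabricated bytes standing in for RemoteTensor.data():

    # Sketch of the torch.frombuffer(...).reshape(...) materialization used above;
    # the buffer is built locally here instead of coming from an HTTP range request.
    import numpy as np
    import torch

    shape = (2, 3)
    raw = np.arange(6, dtype=np.float32).tobytes()   # stand-in for RemoteTensor.data()

    # bytearray keeps the buffer writable, which avoids torch's read-only buffer warning
    t = torch.frombuffer(bytearray(raw), dtype=torch.float32).reshape(shape)
    print(t)   # tensor([[0., 1., 2.], [3., 4., 5.]])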

gguf-py/gguf/utility.py

Lines changed: 22 additions & 10 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from dataclasses import dataclass
 from typing import Literal

 import json

@@ -71,6 +72,20 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     return f"{name}{parameters}{finetune}{version}{encoding}{kind}"


+@dataclass
+class RemoteTensor:
+    dtype: str
+    shape: tuple[int, ...]
+    offset_start: int
+    size: int
+    url: str
+
+    def data(self) -> bytes:
+        # TODO: handle request errors (maybe with limited retries?)
+        data = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)
+        return data
+
+
 class SafetensorRemote:
     """
     Uility class to handle remote safetensor files.

@@ -94,7 +109,7 @@ class SafetensorRemote:
     ALIGNMENT = 8 # bytes

     @classmethod
-    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[int], int, int, str]]:
+    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
         """
         Get list of tensors from a Hugging Face model repository.

@@ -105,10 +120,7 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i
         is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
         if is_single_file:
             url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
-            tensors: dict[str, tuple[str, list[int], int, int, str]] = {}
-            for key, val in cls.get_list_tensors(url).items():
-                tensors[key] = (*val, url) # populate the url
-            return tensors
+            return cls.get_list_tensors(url)

         # case 2: model has multiple files
         index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"

@@ -124,25 +136,25 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, tuple[str, list[i
             all_files = list(set(weight_map.values()))
             all_files.sort() # make sure we load shard files in order
             # get the list of tensors
-            tensors = {}
+            tensors: dict[str, RemoteTensor] = {}
             for file in all_files:
                 url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                 for key, val in cls.get_list_tensors(url).items():
-                    tensors[key] = (*val, url) # populate the url
+                    tensors[key] = val
             return tensors

         raise ValueError(f"Model {model_id} does not have any safetensor files")

     @classmethod
-    def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]]:
+    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
         """
         Get list of tensors from a remote safetensor file.

         Returns a dictionary of tensor names and their metadata.
         Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
         """
         metadata, data_start_offset = cls.get_metadata(url)
-        res: dict[str, tuple[str, list[int], int, int]] = {}
+        res: dict[str, RemoteTensor] = {}

         for name, meta in metadata.items():
             if name == "__metadata__":

@@ -155,7 +167,7 @@ def get_list_tensors(cls, url: str) -> dict[str, tuple[str, list[int], int, int]
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = (dtype, shape, offset_start, size)
+                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
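
The offset arithmetic in get_list_tensors follows the safetensors layout: an 8-byte header-length field, the JSON header, then the data section, with data_offsets relative to the start of that section. A toy worked example with made-up numbers:

    # Toy numbers, not taken from any real file.
    header_json_size = 1_234
    data_start_offset = 8 + header_json_size   # 8-byte length field + JSON header

    meta = {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]}

    offset_start_relative, offset_end_relative = meta["data_offsets"]
    size = offset_end_relative - offset_start_relative        # 24 bytes = 2 * 3 * 4-byte float32
    offset_start = data_start_offset + offset_start_relative  # absolute byte offset in the shard

    print(offset_start, size)   # 1242 24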
