
Commit b0a9053

convert : write tensors in parallel ggml-org#12837
gguf-py : add more clarifying comments for multi-thread writes
Merge branch 'master' into compilade/parallel-convert
gguf-py : use ThreadPoolExecutor when writing tensors
gguf-py : handle (limited) retries for remote tensors
Original author: @compilade
Merge branch 'compilade/parallel-convert' into NXS_Llama.cpp
1 parent 9d60f53 commit b0a9053

File tree

4 files changed: +176 −33 lines changed

convert_hf_to_gguf.py

Lines changed: 13 additions & 3 deletions
@@ -85,7 +85,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
+                 thread_count: int = 2):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:

@@ -116,6 +117,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
+        # self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        # self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name

@@ -134,7 +137,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
 
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
-                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
+                                           thread_count=thread_count)
 
     @classmethod
     def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:

@@ -6501,6 +6505,10 @@ def parse_args() -> argparse.Namespace:
         "--mmproj", action="store_true",
         help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
     )
+    parser.add_argument(
+        "-t", "--threads", type=int, default=2,
+        help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this. Defaults to 2.",
+    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:

@@ -6617,6 +6625,7 @@ def main() -> None:
     hparams = ModelBase.load_hparams(dir_model)
     model_architecture = get_model_architecture(hparams, model_type)
     logger.info(f"Model architecture: {model_architecture}")
+    # model_architecture = hparams["architectures"][0]
     try:
         model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
     except NotImplementedError:

@@ -6630,7 +6639,8 @@ def main() -> None:
                 split_max_tensors=args.split_max_tensors,
                 split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                 small_first_shard=args.no_tensor_first_split,
-                remote_hf_model_id=str(args.model) if args.remote else None)
+                remote_hf_model_id=str(args.model) if args.remote else None,
+                thread_count=args.threads)
 
     if args.vocab_only:
         logger.info("Exporting model vocab...")

gguf-py/gguf/gguf_writer.py

Lines changed: 117 additions & 21 deletions
@@ -5,13 +5,15 @@
 import shutil
 import struct
 import tempfile
+import threading
 from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
 
 import numpy as np
 
@@ -61,8 +63,63 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
+@dataclass
+class _ThreadedTensorWriteInfo:
+    filename: Path
+    offset: int
+    post_pad: int
+    tensor: np.ndarray
+    bar: Any | None  # optional tqdm progress bar
+
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
+        # This is called from a thread pool,
+        # and each thread should have its own file handle per output file
+        # so that they can have different seek locations.
+        f = open_files[self.filename]
+
+        f.seek(self.offset)
+        f.write(self.tensor.data)
+        if self.post_pad > 0:
+            f.write(bytes([0] * self.post_pad))
+        if self.bar is not None:
+            self.bar.update(self.tensor.nbytes)
+
+
 class GGUFWriter:
     fout: list[BufferedWriter] | None
+    filenames: list[Path] | None
+    thread_count: int
     path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
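
The three helpers above combine threading.local with ThreadPoolExecutor's initializer so each worker lazily opens its own handle per output file. A minimal sketch of the same pattern outside GGUFWriter (file name, size, and payloads are made up for illustration):

import threading
from concurrent.futures import ThreadPoolExecutor

local_data = threading.local()

def init_worker(local: threading.local):
    # Runs once in each worker thread: give it a private file handle
    # so concurrent seeks don't interfere with each other.
    local.f = open("out.bin", "r+b")

def write_at(offset: int, payload: bytes):
    f = local_data.f  # this thread's own handle
    f.seek(offset)
    f.write(payload)

# Pre-create and pre-size the file so the "r+b" opens succeed.
with open("out.bin", "wb") as f:
    f.truncate(4 * 1024)

with ThreadPoolExecutor(max_workers=2, initializer=init_worker, initargs=(local_data,)) as ex:
    for i in range(4):
        ex.submit(write_at, i * 1024, bytes([i]) * 1024)
# Note: this sketch leaks the per-thread handles on exit; the real code
# wraps them in _ThreadedOpenFiles, whose __del__ closes them.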
@@ -84,7 +141,8 @@ class GGUFWriter:
 
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        thread_count: int = 2,
     ):
         self.fout = None
         self.path = Path(path) if path else None

@@ -99,6 +157,7 @@ def __init__(
         self.split_max_size = split_max_size
         self.dry_run = dry_run
         self.small_first_shard = small_first_shard
+        self.thread_count = thread_count
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))

@@ -174,6 +233,7 @@ def open_output_file(self, path: Path | None = None) -> None:
 
         if self.path is not None:
             filenames = self.print_plan()
+            self.filenames = filenames
             self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY
 
@@ -425,40 +485,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
             self.write_ti_data_to_file()
 
         assert self.fout is not None
+        assert self.filenames is not None
 
         for fout in self.fout:
             self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
-            shard_bar = None
             bar = None
+            # Initial file offsets before writing the tensor data
+            offsets: list[int] = [fout.tell() for fout in self.fout]
 
             if progress:
+                # TODO: add back the shard bar to show which shard is being written when single-threaded
                 from tqdm import tqdm
 
                 total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 
-                if len(self.fout) > 1:
-                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
-                if shard_bar is not None:
-                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
-                    total = sum(ti.nbytes for ti in tensors.values())
-                    shard_bar.reset(total=(total if total > 0 else None))
-
-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
-                    if shard_bar is not None:
-                        shard_bar.update(ti.nbytes)
-                    if bar is not None:
-                        bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
-                    ti.tensor = None
+            # Allow opening the files only once per worker
+            local_data = threading.local()
+
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Fill the tensor queue with all the pending tensor writes
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
+                        )
+                        ti.tensor = None  # avoid keeping a reference to written tensors
+
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                # not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)
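
The offset bookkeeping in the queue-filling loop reproduces what the sequential writer did implicitly: every tensor starts where the previous aligned write ended, and post_pad rounds it up to data_alignment. A small worked example of that arithmetic (tensor sizes invented; 32 is the GGUF default alignment, and ggml_pad is reimplemented here from the writer's rounding rule to keep the sketch standalone):

def ggml_pad(x: int, n: int) -> int:
    # Round x up to the next multiple of n.
    return ((x + n - 1) // n) * n

alignment = 32
offset = 0
for nbytes in (100, 64, 7):
    start = offset
    offset = ggml_pad(start + nbytes, alignment)
    print(f"{nbytes:>3} bytes -> start={start:>3}, post_pad={offset - (start + nbytes)}")
# 100 bytes -> start=  0, post_pad=28
#  64 bytes -> start=128, post_pad=0
#   7 bytes -> start=192, post_pad=25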

gguf-py/gguf/lazy.py

Lines changed: 5 additions & 0 deletions
@@ -220,4 +220,9 @@ def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
         return eager.tofile(*args, **kwargs)
 
+    @property
+    def data(self):
+        eager = LazyNumpyTensor.to_eager(self)
+        return eager.data
+
     # TODO: __array_function__
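
This property exists because the threaded writer calls f.write(tensor.data) instead of tensor.tofile(f), so a lazy tensor must materialize when its buffer is requested. For reference, on an eager numpy array .data is a zero-copy memoryview of the underlying bytes:

import numpy as np

a = np.arange(4, dtype=np.int32)
buf = a.data                      # memoryview over the array's buffer
assert bytes(buf) == a.tobytes()  # same 16 bytes; the view itself makes no copy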

gguf-py/gguf/utility.py

Lines changed: 41 additions & 9 deletions
@@ -5,6 +5,14 @@
 
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,16 +83,38 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
     size: int
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+            break  # success: stop retrying
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data
 
 
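For the record, the sleep comment enumerates the backoff schedule: attempt i sleeps 2 * i + 1 seconds, so the seven retriable failures before the final raise wait 1 + 3 + 5 + 7 + 9 + 11 + 13 = 49 seconds in total.
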
@@ -169,7 +199,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                res[name] = RemoteTensor(
+                    name=name,
+                    dtype=dtype,
+                    shape=tuple(shape),
+                    offset_start=offset_start,
+                    size=size,
+                    url=url,
+                )
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 

@@ -217,8 +254,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:

@@ -239,9 +274,6 @@ def check_file_exist(cls, url: str) -> bool:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
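
For context, SafetensorRemote.get_data_by_range (which these hunks switch to the module-level requests/urlparse imports) is essentially an HTTP Range request. A simplified standalone sketch, assuming a placeholder URL and ignoring the real method's size=-1 whole-file case and shared request headers:

import requests
from urllib.parse import urlparse

def fetch_range(url: str, start: int, size: int) -> bytes:
    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(f"Invalid URL: {url}")
    # HTTP byte ranges are inclusive on both ends.
    headers = {"Range": f"bytes={start}-{start + size - 1}"}
    response = requests.get(url, allow_redirects=True, headers=headers)
    response.raise_for_status()
    return response.content

# e.g. fetch_range("https://example.com/model.safetensors", 0, 8)
# would fetch the 8-byte length prefix of a safetensors header.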
