
Commit 06e1d31

convert : write tensors in parallel
1 parent b32efad commit 06e1d31

3 files changed: 95 additions & 19 deletions


convert_hf_to_gguf.py

Lines changed: 8 additions & 3 deletions
@@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, thread_count: int = 2):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -109,7 +109,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
 
         # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
-                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
+                                           thread_count=thread_count)
 
     @classmethod
     def __init_subclass__(cls):
@@ -5470,6 +5471,10 @@ def parse_args() -> argparse.Namespace:
         "--print-supported-models", action="store_true",
         help="Print the supported models"
     )
+    parser.add_argument(
+        "-t", "--threads", type=int, default=2,
+        help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this.",
+    )
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -5554,7 +5559,7 @@ def main() -> None:
             metadata_override=args.metadata, model_name=args.model_name,
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
-            small_first_shard=args.no_tensor_first_split)
+            small_first_shard=args.no_tensor_first_split, thread_count=args.threads)
 
         if args.vocab_only:
             logger.info("Exporting model vocab...")

gguf-py/gguf/gguf_writer.py

Lines changed: 82 additions & 16 deletions
@@ -5,10 +5,12 @@
 import shutil
 import struct
 import tempfile
+import threading
 from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
+from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
@@ -60,8 +62,31 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+@dataclass
+class TensorWriteInfo:
+    filename: Path
+    offset: int
+    post_pad: int
+    tensor: np.ndarray
+    bar: Any | None
+
+    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+        if self.filename not in open_files:
+            open_files[self.filename] = open(self.filename, "r+b")
+        f = open_files[self.filename]
+
+        f.seek(self.offset)
+        f.write(self.tensor.data)
+        if self.post_pad > 0:
+            f.write(bytes([0] * self.post_pad))
+        if self.bar is not None:
+            self.bar.update(self.tensor.nbytes)
+
+
 class GGUFWriter:
     fout: list[BufferedWriter] | None
+    filenames: list[Path] | None
+    thread_count: int
     path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
@@ -83,7 +108,8 @@ class GGUFWriter:
 
     def __init__(
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
+        thread_count: int = 2,
     ):
         self.fout = None
         self.path = Path(path) if path else None
@@ -98,6 +124,7 @@ def __init__(
         self.split_max_size = split_max_size
         self.dry_run = dry_run
         self.small_first_shard = small_first_shard
+        self.thread_count = thread_count
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
@@ -173,6 +200,7 @@ def open_output_file(self, path: Path | None = None) -> None:
 
         if self.path is not None:
             filenames = self.print_plan()
+            self.filenames = filenames
             self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY
 
@@ -424,40 +452,78 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         assert self.fout is not None
+        assert self.filenames is not None
 
         for fout in self.fout:
             self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
-            shard_bar = None
             bar = None
+            # Distribute writing the tensors between multiple threads
+            tensor_queue: Queue[TensorWriteInfo] = Queue()
+
+            offsets: list[int] = [fout.tell() for fout in self.fout]
 
             if progress:
+                # TODO: add back the shard bar to show which shard is being written when single-threaded
                 from tqdm import tqdm
 
                 total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 
-                if len(self.fout) > 1:
-                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
-                if shard_bar is not None:
-                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
-                    total = sum(ti.nbytes for ti in tensors.values())
-                    shard_bar.reset(total=(total if total > 0 else None))
+            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                offset = offsets[i]
 
                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
                 for ti in tensors.values():
                     assert ti.tensor is not None  # can only iterate once over the tensors
                     assert ti.tensor.nbytes == ti.nbytes
-                    ti.tensor.tofile(fout)
-                    if shard_bar is not None:
-                        shard_bar.update(ti.nbytes)
-                    if bar is not None:
-                        bar.update(ti.nbytes)
-                    self.write_padding(fout, ti.nbytes)
-                    ti.tensor = None
+                    start_offset = offset
+                    nbytes = ti.tensor.nbytes
+                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                    padding = offset - (start_offset + nbytes)
+                    tensor_queue.put(
+                        TensorWriteInfo(
+                            filename=filename,
+                            offset=start_offset,
+                            post_pad=padding,
+                            tensor=ti.tensor,
+                            bar=bar,
+                        )
+                    )
+                    ti.tensor = None  # avoid keeping a reference to written tensors
+
+            # Write tensors in parallel
+            # TODO: total tensor size limit for the running threads
+            def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
+                open_files: dict[Path, BufferedWriter] = {}
+                try:
+                    while t := queue.get_nowait():
+                        t.write_chunk(open_files)
+                        del t
+                        queue.task_done()
+                except Empty:
+                    pass
+
+                for f in open_files.values():
+                    f.close()
+
+            threads = [
+                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
+                for _ in range(self.thread_count)
+            ]
+
+            for t in threads:
+                t.start()
+
+            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
+            # so waiting for the queue to be "done" first.
+            tensor_queue.join()
+
+            for t in threads:
+                t.join()
+
 
         else:
            self.temp_file.seek(0)
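The writer now splits the work into a plan phase and a write phase: the main thread walks the tensors in order, assigns each one a byte offset padded to the data alignment, and pushes one work item per tensor onto a queue; worker threads then pop items and write them with plain seek-and-write calls, each keeping its own file handles because a seek position belongs to the handle, not the file. A minimal, self-contained sketch of that pattern follows; the file name, alignment value, and helper names are illustrative, not the gguf API.

import threading
from pathlib import Path
from queue import Empty, Queue

import numpy as np

ALIGNMENT = 32  # illustrative; the real writer uses its own data_alignment


def pad_to(offset: int, alignment: int) -> int:
    # round an offset up to the next multiple of the alignment
    return offset + (alignment - offset % alignment) % alignment


def parallel_write(path: Path, tensors: list[np.ndarray], thread_count: int = 2) -> None:
    # plan phase: give every tensor a fixed, aligned offset up front,
    # so the workers may write them in any order
    queue: Queue[tuple[int, np.ndarray]] = Queue()
    offset = 0
    for t in tensors:
        queue.put((offset, t))
        offset = pad_to(offset + t.nbytes, ALIGNMENT)

    # pre-size the file so every worker can reopen it "r+b" and seek anywhere
    with open(path, "wb") as f:
        f.truncate(offset)

    def worker() -> None:
        # each thread opens its own handle: seek positions are per-handle
        with open(path, "r+b") as f:
            try:
                while item := queue.get_nowait():
                    off, tensor = item
                    f.seek(off)
                    f.write(tensor.data)
                    queue.task_done()
            except Empty:
                pass

    threads = [threading.Thread(target=worker) for _ in range(thread_count)]
    for th in threads:
        th.start()
    queue.join()  # wait on the queue first; joining threads directly
    for th in threads:
        th.join()  # interacts badly with KeyboardInterrupt


parallel_write(Path("demo.bin"), [np.full(n, 1.0, dtype=np.float32) for n in (3, 5, 7)])

Because every tensor's slot is fixed in advance, no locking is needed around the writes; the only shared state is the queue. The trade-off, noted in the new --threads help text and the TODO above, is that each in-flight work item holds a whole tensor in memory, so the thread count bounds peak RAM use.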

gguf-py/gguf/lazy.py

Lines changed: 5 additions & 0 deletions
@@ -220,4 +220,9 @@ def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
         return eager.tofile(*args, **kwargs)
 
+    @property
+    def data(self):
+        eager = LazyNumpyTensor.to_eager(self)
+        return eager.data
+
     # TODO: __array_function__
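The new data property mirrors what the threaded writer needs: write_chunk() calls f.write(tensor.data) instead of tensor.tofile(f), so a lazy tensor has to turn itself into a real numpy array the moment its buffer is requested. A rough stand-in illustrating the behaviour (the class below is a toy, not LazyNumpyTensor):

import numpy as np


class LazyTensorSketch:
    # toy stand-in: defers building the array until its buffer is needed
    def __init__(self, make_array):
        self._make_array = make_array

    @property
    def data(self):
        # materialize on access, like the new property in lazy.py
        return self._make_array().data


lazy = LazyTensorSketch(lambda: np.zeros(4, dtype=np.float32))
buf = lazy.data  # triggers the computation and yields a writable buffer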
