5
5
import shutil
6
6
import struct
7
7
import tempfile
8
+ import threading
8
9
from dataclasses import dataclass
9
10
from enum import Enum , auto
10
11
from math import prod
11
12
from pathlib import Path
13
+ from queue import Empty , Queue
12
14
from io import BufferedWriter
13
15
from typing import IO , Any , Sequence , Mapping
14
16
from string import ascii_letters , digits
@@ -60,8 +62,31 @@ class WriterState(Enum):
60
62
WEIGHTS = auto ()
61
63
62
64
65
+ @dataclass
66
+ class TensorWriteInfo :
67
+ filename : Path
68
+ offset : int
69
+ post_pad : int
70
+ tensor : np .ndarray
71
+ bar : Any | None
72
+
73
+ def write_chunk (self , open_files : dict [Path , BufferedWriter ]):
74
+ if self .filename not in open_files :
75
+ open_files [self .filename ] = open (self .filename , "r+b" )
76
+ f = open_files [self .filename ]
77
+
78
+ f .seek (self .offset )
79
+ f .write (self .tensor .data )
80
+ if self .post_pad > 0 :
81
+ f .write (bytes ([0 ] * self .post_pad ))
82
+ if self .bar is not None :
83
+ self .bar .update (self .tensor .nbytes )
84
+
85
+
63
86
class GGUFWriter :
64
87
fout : list [BufferedWriter ] | None
88
+ filenames : list [Path ] | None
89
+ thread_count : int
65
90
path : Path | None
66
91
temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
67
92
tensors : list [dict [str , TensorInfo ]]
@@ -83,7 +108,8 @@ class GGUFWriter:
83
108
84
109
def __init__ (
85
110
self , path : os .PathLike [str ] | str | None , arch : str , use_temp_file : bool = False , endianess : GGUFEndian = GGUFEndian .LITTLE ,
86
- split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False
111
+ split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False ,
112
+ thread_count : int = 2 ,
87
113
):
88
114
self .fout = None
89
115
self .path = Path (path ) if path else None
@@ -98,6 +124,7 @@ def __init__(
98
124
self .split_max_size = split_max_size
99
125
self .dry_run = dry_run
100
126
self .small_first_shard = small_first_shard
127
+ self .thread_count = thread_count
101
128
logger .info ("gguf: This GGUF file is for {0} Endian only" .format (
102
129
"Big" if self .endianess == GGUFEndian .BIG else "Little" ,
103
130
))
@@ -173,6 +200,7 @@ def open_output_file(self, path: Path | None = None) -> None:
173
200
174
201
if self .path is not None :
175
202
filenames = self .print_plan ()
203
+ self .filenames = filenames
176
204
self .fout = [open (filename , "wb" ) for filename in filenames ]
177
205
self .state = WriterState .EMPTY
178
206
@@ -424,40 +452,78 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
424
452
self .write_ti_data_to_file ()
425
453
426
454
assert self .fout is not None
455
+ assert self .filenames is not None
427
456
428
457
for fout in self .fout :
429
458
self .write_padding (fout , fout .tell ())
430
459
431
460
if self .temp_file is None :
432
- shard_bar = None
433
461
bar = None
462
+ # Distribute writing the tensors between multiple threads
463
+ tensor_queue : Queue [TensorWriteInfo ] = Queue ()
464
+
465
+ offsets : list [int ] = [fout .tell () for fout in self .fout ]
434
466
435
467
if progress :
468
+ # TODO: add back the shard bar to show which shard is being written when single-threaded
436
469
from tqdm import tqdm
437
470
438
471
total_bytes = sum (ti .nbytes for t in self .tensors for ti in t .values ())
439
472
440
- if len (self .fout ) > 1 :
441
- shard_bar = tqdm (desc = f"Shard (0/{ len (self .fout )} )" , total = None , unit = "byte" , unit_scale = True )
442
473
bar = tqdm (desc = "Writing" , total = total_bytes , unit = "byte" , unit_scale = True )
443
474
444
- for i , (fout , tensors ) in enumerate (zip (self .fout , self .tensors )):
445
- if shard_bar is not None :
446
- shard_bar .set_description (f"Shard ({ i + 1 } /{ len (self .fout )} )" )
447
- total = sum (ti .nbytes for ti in tensors .values ())
448
- shard_bar .reset (total = (total if total > 0 else None ))
475
+ for i , (filename , tensors ) in enumerate (zip (self .filenames , self .tensors )):
476
+ offset = offsets [i ]
449
477
450
478
# relying on the fact that Python dicts preserve insertion order (since 3.7)
451
479
for ti in tensors .values ():
452
480
assert ti .tensor is not None # can only iterate once over the tensors
453
481
assert ti .tensor .nbytes == ti .nbytes
454
- ti .tensor .tofile (fout )
455
- if shard_bar is not None :
456
- shard_bar .update (ti .nbytes )
457
- if bar is not None :
458
- bar .update (ti .nbytes )
459
- self .write_padding (fout , ti .nbytes )
460
- ti .tensor = None
482
+ start_offset = offset
483
+ nbytes = ti .tensor .nbytes
484
+ offset = self .ggml_pad (start_offset + nbytes , self .data_alignment )
485
+ padding = offset - (start_offset + nbytes )
486
+ tensor_queue .put (
487
+ TensorWriteInfo (
488
+ filename = filename ,
489
+ offset = start_offset ,
490
+ post_pad = padding ,
491
+ tensor = ti .tensor ,
492
+ bar = bar ,
493
+ )
494
+ )
495
+ ti .tensor = None # avoid keeping a reference to written tensors
496
+
497
+ # Write tensors in parallel
498
+ # TODO: total tensor size limit for the running threads
499
def write_tensors_from_thread(queue: "Queue[TensorWriteInfo]"):
    """Worker body: drain ``queue`` and write every entry to its file.

    Entries are taken without blocking until the queue is empty. Each entry's
    ``write_chunk`` is called with a per-worker cache of open files, and
    ``queue.task_done()`` is called after each successful write. All files
    opened by this worker are closed before returning, including when a write
    raises.
    """
    open_files: dict[Path, BufferedWriter] = {}
    try:
        while True:
            # Only the dequeue may legitimately raise Empty; keep the handler
            # narrow so errors from the write itself are never mistaken for it.
            try:
                t = queue.get_nowait()
            except Empty:
                break
            t.write_chunk(open_files)
            # Drop the reference to the written entry before marking it done.
            del t
            queue.task_done()
            # NOTE(review): if write_chunk() raises, task_done() is not called
            # for that entry, so a caller blocked in queue.join() will not
            # return. Failure handling belongs at the call site.
    finally:
        # Always release the handles this worker opened, even on failure.
        for f in open_files.values():
            f.close()
511
+
512
+ threads = [
513
+ threading .Thread (target = write_tensors_from_thread , args = (tensor_queue ,))
514
+ for _ in range (self .thread_count )
515
+ ]
516
+
517
+ for t in threads :
518
+ t .start ()
519
+
520
+ # NOTE: thread joining has weird interactions with KeyboardInterrupt,
521
+ # so waiting for the queue to be "done" first.
522
+ tensor_queue .join ()
523
+
524
+ for t in threads :
525
+ t .join ()
526
+
461
527
else :
462
528
self .temp_file .seek (0 )
463
529
0 commit comments