5
5
import shutil
6
6
import struct
7
7
import tempfile
8
+ import threading
8
9
from dataclasses import dataclass
9
10
from enum import Enum , auto
10
11
from math import prod
11
12
from pathlib import Path
12
13
from io import BufferedWriter
13
14
from typing import IO , Any , Sequence , Mapping
14
15
from string import ascii_letters , digits
16
+ from concurrent .futures import FIRST_EXCEPTION , Future , ThreadPoolExecutor , wait
15
17
16
18
import numpy as np
17
19
@@ -61,8 +63,63 @@ class WriterState(Enum):
61
63
WEIGHTS = auto ()
62
64
63
65
66
+ # To close files which were opened in thread-local context
67
+ # Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
68
+ # ref: https://github.com/python/cpython/issues/89502
69
class _ThreadedOpenFiles:
    """Cache of file handles, intended to be owned by a single worker thread.

    Each path is opened at most once per instance, in ``"r+b"`` mode, so the
    file must already exist. Keeping one instance per thread gives every
    worker its own handle (and therefore its own seek position) per file.

    Handles are released by :meth:`close`, which is also invoked from
    ``__del__`` because ``ThreadPoolExecutor`` offers no per-thread finalizer
    hook (ref: https://github.com/python/cpython/issues/89502).
    """

    files: dict[Path, BufferedWriter]

    def __init__(self):
        self.files = {}

    def close(self) -> None:
        """Close every cached handle and empty the cache.

        All handles are closed even if one of them fails; the first
        ``OSError`` encountered is re-raised once the rest are closed.
        Calling this more than once is harmless.
        """
        first_error: OSError | None = None
        # getattr guards against __del__ running on a half-built instance
        files = getattr(self, "files", None)
        while files:
            _, handle = files.popitem()
            try:
                handle.close()
            except OSError as err:
                # keep going so one bad handle does not leak the others
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def __del__(self):
        self.close()

    def __getitem__(self, key: Path, /) -> BufferedWriter:
        """Return the handle for ``key``, opening it on first access."""
        try:
            return self.files[key]
        except KeyError:
            handle = self.files[key] = open(key, "r+b")
            return handle

    @classmethod
    def init_thread_local(cls, local_data):
        """Attach a fresh cache to ``local_data`` as ``open_files``.

        Meant to be used as a thread pool ``initializer`` with a
        ``threading.local()`` instance as its argument.
        """
        local_data.open_files = cls()
87
+
88
+
89
+ # Exit quickly instead of waiting
90
+ class _InterruptibleThreadPoolExecutor (ThreadPoolExecutor ):
91
+ def __exit__ (self , exc_type , exc_val , exc_tb ) -> bool | None :
92
+ del exc_type , exc_val , exc_tb
93
+ self .shutdown (wait = False , cancel_futures = True )
94
+ return False
95
+
96
+
97
+ @dataclass
98
+ class _ThreadedTensorWriteInfo :
99
+ filename : Path
100
+ offset : int
101
+ post_pad : int
102
+ tensor : np .ndarray
103
+ bar : Any | None # optional tqdm progress bar
104
+
105
+ def write_chunk (self , open_files : _ThreadedOpenFiles ):
106
+ # This is called from a thread pool,
107
+ # and each thread should have its own file handle per output file
108
+ # so that they can have different seek locations.
109
+ f = open_files [self .filename ]
110
+
111
+ f .seek (self .offset )
112
+ f .write (self .tensor .data )
113
+ if self .post_pad > 0 :
114
+ f .write (bytes ([0 ] * self .post_pad ))
115
+ if self .bar is not None :
116
+ self .bar .update (self .tensor .nbytes )
117
+
118
+
64
119
class GGUFWriter :
65
120
fout : list [BufferedWriter ] | None
121
+ filenames : list [Path ] | None
122
+ thread_count : int
66
123
path : Path | None
67
124
temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
68
125
tensors : list [dict [str , TensorInfo ]]
@@ -84,7 +141,8 @@ class GGUFWriter:
84
141
85
142
def __init__ (
86
143
self , path : os .PathLike [str ] | str | None , arch : str , use_temp_file : bool = False , endianess : GGUFEndian = GGUFEndian .LITTLE ,
87
- split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False
144
+ split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False ,
145
+ thread_count : int = 2 ,
88
146
):
89
147
self .fout = None
90
148
self .path = Path (path ) if path else None
@@ -99,6 +157,7 @@ def __init__(
99
157
self .split_max_size = split_max_size
100
158
self .dry_run = dry_run
101
159
self .small_first_shard = small_first_shard
160
+ self .thread_count = thread_count
102
161
logger .info ("gguf: This GGUF file is for {0} Endian only" .format (
103
162
"Big" if self .endianess == GGUFEndian .BIG else "Little" ,
104
163
))
@@ -174,6 +233,7 @@ def open_output_file(self, path: Path | None = None) -> None:
174
233
175
234
if self .path is not None :
176
235
filenames = self .print_plan ()
236
+ self .filenames = filenames
177
237
self .fout = [open (filename , "wb" ) for filename in filenames ]
178
238
self .state = WriterState .EMPTY
179
239
@@ -425,40 +485,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
425
485
self .write_ti_data_to_file ()
426
486
427
487
assert self .fout is not None
488
+ assert self .filenames is not None
428
489
429
490
for fout in self .fout :
430
491
self .write_padding (fout , fout .tell ())
431
492
432
493
if self .temp_file is None :
433
- shard_bar = None
434
494
bar = None
495
+ # Initial file offsets before writing the tensor data
496
+ offsets : list [int ] = [fout .tell () for fout in self .fout ]
435
497
436
498
if progress :
499
+ # TODO: add back the shard bar to show which shard is being written when single-threaded
437
500
from tqdm import tqdm
438
501
439
502
total_bytes = sum (ti .nbytes for t in self .tensors for ti in t .values ())
440
503
441
- if len (self .fout ) > 1 :
442
- shard_bar = tqdm (desc = f"Shard (0/{ len (self .fout )} )" , total = None , unit = "byte" , unit_scale = True )
443
504
bar = tqdm (desc = "Writing" , total = total_bytes , unit = "byte" , unit_scale = True )
444
505
445
- for i , (fout , tensors ) in enumerate (zip (self .fout , self .tensors )):
446
- if shard_bar is not None :
447
- shard_bar .set_description (f"Shard ({ i + 1 } /{ len (self .fout )} )" )
448
- total = sum (ti .nbytes for ti in tensors .values ())
449
- shard_bar .reset (total = (total if total > 0 else None ))
450
-
451
- # relying on the fact that Python dicts preserve insertion order (since 3.7)
452
- for ti in tensors .values ():
453
- assert ti .tensor is not None # can only iterate once over the tensors
454
- assert ti .tensor .nbytes == ti .nbytes
455
- ti .tensor .tofile (fout )
456
- if shard_bar is not None :
457
- shard_bar .update (ti .nbytes )
458
- if bar is not None :
459
- bar .update (ti .nbytes )
460
- self .write_padding (fout , ti .nbytes )
461
- ti .tensor = None
506
+ # Allow opening the files only once per worker
507
+ local_data = threading .local ()
508
+
509
+ # Unit of work
510
+ def thread_write_tensor (tensor : _ThreadedTensorWriteInfo ):
511
+ tensor .write_chunk (local_data .open_files )
512
+
513
+ with _InterruptibleThreadPoolExecutor (
514
+ max_workers = self .thread_count ,
515
+ initializer = _ThreadedOpenFiles .init_thread_local ,
516
+ initargs = (local_data ,),
517
+ ) as executor :
518
+
519
+ futures : list [Future ] = []
520
+
521
+ # Fill the tensor queue with all the pending tensor writes
522
+ for i , (filename , tensors ) in enumerate (zip (self .filenames , self .tensors )):
523
+ offset = offsets [i ]
524
+
525
+ # relying on the fact that Python dicts preserve insertion order (since 3.7)
526
+ for ti in tensors .values ():
527
+ assert ti .tensor is not None # can only iterate once over the tensors
528
+ assert ti .tensor .nbytes == ti .nbytes
529
+ start_offset = offset
530
+ nbytes = ti .tensor .nbytes
531
+ offset = self .ggml_pad (start_offset + nbytes , self .data_alignment )
532
+ padding = offset - (start_offset + nbytes )
533
+ futures .append (
534
+ executor .submit (
535
+ thread_write_tensor ,
536
+ _ThreadedTensorWriteInfo (
537
+ filename = filename ,
538
+ offset = start_offset ,
539
+ post_pad = padding ,
540
+ tensor = ti .tensor ,
541
+ bar = bar ,
542
+ ),
543
+ )
544
+ )
545
+ ti .tensor = None # avoid keeping a reference to written tensors
546
+
547
+ # FIXME: there's still some weird behavior with KeyboardInterrupt
548
+ # not being able to interrupt a future mid-execution
549
+ done , not_done = wait (futures , return_when = FIRST_EXCEPTION )
550
+ exc = None
551
+ if any (f for f in done
552
+ if not f .cancelled () and (exc := f .exception ()) is not None ):
553
+ raise RuntimeError ("Error writing tensors" ) from exc
554
+ elif len (not_done ) != 0 :
555
+ raise RuntimeError ("Not all tensors were written" )
556
+
557
+ del local_data
462
558
else :
463
559
self .temp_file .seek (0 )
464
560
0 commit comments