@@ -63,14 +63,17 @@ class WriterState(Enum):


 @dataclass
-class TensorWriteInfo:
+class ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
-    bar: Any | None
+    bar: Any | None  # optional tqdm progress bar

     def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+        # This is called from a thread pool,
+        # and each thread should have its own file handle per output file
+        # so that they can have different seek locations.
         if self.filename not in open_files:
             open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
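The new comment in write_chunk states the key constraint of this design: several worker threads may write into the same output file at different offsets, so each thread needs its own handle per file to keep an independent seek position (the PR caches those handles per thread in open_files, so each thread opens each file only once). A minimal standalone sketch of the idea, assuming a hypothetical write_at helper and example.bin file that are not part of the PR:

# Sketch only: concurrent writers to one file, each with a private handle,
# so their seek positions cannot interfere with each other.
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

path = Path("example.bin")      # hypothetical file, not from the PR
path.write_bytes(b"\x00" * 16)  # pre-size the file before the offset writes

def write_at(offset: int, data: bytes) -> None:
    # A handle opened inside the worker keeps the seek position thread-private.
    with open(path, "r+b") as f:
        f.seek(offset)
        f.write(data)

with ThreadPoolExecutor(max_workers=2) as pool:
    pool.submit(write_at, 0, b"AAAA")
    pool.submit(write_at, 8, b"BBBB")

print(path.read_bytes())  # b'AAAA\x00\x00\x00\x00BBBB\x00\x00\x00\x00'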
@@ -460,8 +463,9 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
         if self.temp_file is None:
             bar = None
             # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[TensorWriteInfo] = Queue()
+            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()

+            # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]

             if progress:
@@ -472,6 +476,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:

                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

+            # Fill the tensor queue with all the pending tensor writes
             for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
                 offset = offsets[i]

@@ -484,7 +489,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
                     offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
                     padding = offset - (start_offset + nbytes)
                     tensor_queue.put(
-                        TensorWriteInfo(
+                        ThreadedTensorWriteInfo(
                             filename=filename,
                             offset=start_offset,
                             post_pad=padding,
@@ -496,12 +501,13 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:

             # Write tensors in parallel
             # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
+            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
+                # Opening the files only once per thread
                 open_files: dict[Path, BufferedWriter] = {}
                 try:
-                    while t := queue.get_nowait():
-                        t.write_chunk(open_files)
-                        del t
+                    while tensor := queue.get_nowait():
+                        tensor.write_chunk(open_files)
+                        del tensor
                         queue.task_done()
                 except Empty:
                     pass
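This hunk only changes the worker itself; how the threads are started lies outside the shown context. A sketch of the general pattern, under the assumption that a small fixed pool of threads drains the queue (the worker name, item type, and thread count below are illustrative, not taken from the PR):

# Sketch only: drain a Queue of work items with a few threads, in the same
# style as write_tensors_from_thread above.
from queue import Empty, Queue
from threading import Thread

work: Queue[int] = Queue()
for item in range(10):
    work.put(item)

def worker(queue: Queue[int]) -> None:
    try:
        # get_nowait() raises Empty once the queue is drained, ending the loop.
        while (item := queue.get_nowait()) is not None:
            print(f"processed {item}")
            queue.task_done()
    except Empty:
        pass

threads = [Thread(target=worker, args=(work,)) for _ in range(4)]
for t in threads:
    t.start()
work.join()  # returns once task_done() has been called for every queued item
for t in threads:
    t.join()

As in the diff, each worker exits through the Empty exception rather than a sentinel value, so the pool winds down naturally once the queue is empty.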