Skip to content

Commit d8bab9e

Browse files
committed
gguf-py : add more clarifying comments for multi-thread writes
1 parent 06e1d31 commit d8bab9e

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

gguf-py/gguf/gguf_writer.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -63,14 +63,17 @@ class WriterState(Enum):
6363

6464

6565
@dataclass
66-
class TensorWriteInfo:
66+
class ThreadedTensorWriteInfo:
6767
filename: Path
6868
offset: int
6969
post_pad: int
7070
tensor: np.ndarray
71-
bar: Any | None
71+
bar: Any | None # optional tqdm progress bar
7272

7373
def write_chunk(self, open_files: dict[Path, BufferedWriter]):
74+
# This is called from a thread pool,
75+
# and each thread should have its own file handle per output file
76+
# so that they can have different seek locations.
7477
if self.filename not in open_files:
7578
open_files[self.filename] = open(self.filename, "r+b")
7679
f = open_files[self.filename]
@@ -460,8 +463,9 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
460463
if self.temp_file is None:
461464
bar = None
462465
# Distribute writing the tensors between multiple threads
463-
tensor_queue: Queue[TensorWriteInfo] = Queue()
466+
tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
464467

468+
# Initial file offsets before writing the tensor data
465469
offsets: list[int] = [fout.tell() for fout in self.fout]
466470

467471
if progress:
@@ -472,6 +476,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
472476

473477
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
474478

479+
# Fill the tensor queue with all the pending tensor writes
475480
for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
476481
offset = offsets[i]
477482

@@ -484,7 +489,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
484489
offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
485490
padding = offset - (start_offset + nbytes)
486491
tensor_queue.put(
487-
TensorWriteInfo(
492+
ThreadedTensorWriteInfo(
488493
filename=filename,
489494
offset=start_offset,
490495
post_pad=padding,
@@ -496,12 +501,13 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
496501

497502
# Write tensors in parallel
498503
# TODO: total tensor size limit for the running threads
499-
def write_tensors_from_thread(queue: Queue[TensorWriteInfo]):
504+
def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
505+
# Opening the files only once per thread
500506
open_files: dict[Path, BufferedWriter] = {}
501507
try:
502-
while t := queue.get_nowait():
503-
t.write_chunk(open_files)
504-
del t
508+
while tensor := queue.get_nowait():
509+
tensor.write_chunk(open_files)
510+
del tensor
505511
queue.task_done()
506512
except Empty:
507513
pass

0 commit comments

Comments
 (0)