Skip to content

Commit 655990c

Browse files
committed
Add --concurrency option
Minor improvements to help text. Clean up the bounded_parallel_map function a bit.
1 parent 0ddeeba commit 655990c

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

convert.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
ARCH=gguf.MODEL_ARCH.LLAMA
4040
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
4141

42+
DEFAULT_CONCURRENCY = 8
4243
#
4344
# data types
4445
#
@@ -717,21 +718,21 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
717718
with factory(max_workers = max_workers) as executor:
718719
futures: List[concurrent.futures.Future[Out]] = []
719720
done = False
720-
for i in range(concurrency):
721+
for _ in range(concurrency):
721722
try:
722-
nexti = next(iterable)
723+
futures.append(executor.submit(func, next(iterable)))
723724
except StopIteration:
725+
done = True
724726
break
725-
futures.append(executor.submit(func, nexti))
726-
while not done or futures:
727+
728+
while futures:
727729
result = futures.pop(0).result()
728-
while len(futures) < concurrency:
730+
while not done and len(futures) < concurrency:
729731
try:
730-
nexti = next(iterable)
732+
futures.append(executor.submit(func, next(iterable)))
731733
except StopIteration:
732734
done = True
733735
break
734-
futures.append(executor.submit(func, nexti))
735736
yield result
736737

737738
def check_vocab_size(params: Params, vocab: Vocab) -> None:
@@ -850,13 +851,13 @@ def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray):
850851
return (lazy_tensor.data_type, tensor.ndarray)
851852

852853
@staticmethod
853-
def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
854+
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
854855
if item[0] == DT_Q8_0:
855856
return quantize_array_q8_0(item[1])
856857
return item[1]
857858

858859
@staticmethod
859-
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
860+
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
860861
check_vocab_size(params, vocab)
861862

862863
of = OutputFile(fname_out)
@@ -873,11 +874,11 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM
873874
of.write_tensor_info()
874875

875876
# tensor data
876-
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
877+
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
877878
if ftype == GGMLFileType.MostlyQ8_0:
878-
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
879+
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
879880
else:
880-
ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
881+
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)
881882

882883
start = time.time()
883884
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
@@ -1073,12 +1074,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
10731074
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
10741075
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
10751076
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
1076-
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
1077+
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
10771078
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
10781079
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
10791080
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
10801081
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
10811082
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
1083+
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
10821084
args = parser.parse_args(args_in)
10831085

10841086
if args.dump_single:
@@ -1132,7 +1134,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
11321134
params.ftype = ftype
11331135
print(f"Writing {outfile}, format {ftype}")
11341136

1135-
OutputFile.write_all(outfile, ftype, params, model, vocab)
1137+
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
11361138
print(f"Wrote {outfile}")
11371139

11381140

0 commit comments

Comments
 (0)