
Commit bd01667

mikekgfb authored and malfet committed
add --parallel-prefill options & option validation (#368)
* add --parallel-prefill options, option validation, and refactor option validation
* handle model is None for model validation
* typo
* move model compile options to generator args
* typo
* typo
* typo
* refactor
* update eval
1 parent 5480866 · commit bd01667

File tree

build/builder.py
cli.py
download.py
eval.py
export_et_util.py
generate.py
quantize.py

7 files changed: +99 -40 lines

build/builder.py

Lines changed: 26 additions & 14 deletions
@@ -88,7 +88,9 @@ def from_args(cls, args): # -> BuilderArgs:
         )
         # The transformers config is keyed on the last section
         # of the name/path.
-        params_table = model_config.transformer_params_key or model_config.name.split("/")[-1]
+        params_table = (
+            model_config.transformer_params_key or model_config.name.split("/")[-1]
+        )

         is_chat_model = False
         if args.is_chat_model:
@@ -143,6 +145,24 @@ class TokenizerArgs:
     is_sentencepiece: bool = True
     is_tiktoken: bool = False

+    def validate_model(
+        self,
+        model: Transformer,
+        model_description: str = "model",
+    ):
+        if model is None:
+            return
+
+        use_tiktoken = model.config.use_tiktoken
+        is_tiktoken = self.is_tiktoken
+
+        if use_tiktoken is None:
+            model.config.use_tiktoken = is_tiktoken
+        elif use_tiktoken != is_tiktoken:
+            raise RuntimeError(
+                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
+            )
+
     @classmethod
     def from_args(cls, args): # -> TokenizerArgs:
         is_sentencepiece = True
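The new TokenizerArgs.validate_model replaces the free-standing validate_args removed later in this file's diff: it reconciles the tokenizer type the model config expects with the tokenizer actually loaded, adopting the loaded one when the model is agnostic, and it now tolerates model=None (the draft model is optional). A minimal sketch of that behavior, using hypothetical stand-in classes rather than torchchat's real Transformer and config types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubConfig:  # stand-in: None means the model does not pin a tokenizer
    use_tiktoken: Optional[bool] = None

@dataclass
class StubModel:  # stand-in for torchchat's Transformer
    config: StubConfig

def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
    return "TikToken" if tiktoken else "SentencePiece"

def validate_model(model, is_tiktoken: bool, model_description: str = "model"):
    # Mirrors TokenizerArgs.validate_model: skip absent models, adopt the
    # loaded tokenizer if the model is agnostic, otherwise require agreement.
    if model is None:
        return
    if model.config.use_tiktoken is None:
        model.config.use_tiktoken = is_tiktoken
    elif model.config.use_tiktoken != is_tiktoken:
        raise RuntimeError(
            f"model-specified tokenizer "
            f"({tokenizer_setting_to_name(model.config.use_tiktoken)}) does not "
            f"match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) "
            f"for {model_description}"
        )

validate_model(None, is_tiktoken=True)   # absent draft model: no-op
m = StubModel(StubConfig())
validate_model(m, is_tiktoken=True)      # agnostic model adopts TikToken
assert m.config.use_tiktoken is True
try:
    validate_model(m, is_tiktoken=False, model_description="draft model")
except RuntimeError as e:
    print(e)                             # mismatch is reported, not ignored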
@@ -152,7 +172,11 @@ def from_args(cls, args): # -> TokenizerArgs:
         tokenizer_path = args.tokenizer_path
     elif args.model: # Using a named, well-known model
         model_config = resolve_model_config(args.model)
-        tokenizer_path = Path(args.model_directory) / model_config.name / model_config.tokenizer_file
+        tokenizer_path = (
+            Path(args.model_directory)
+            / model_config.name
+            / model_config.tokenizer_file
+        )

     elif args.checkpoint_path:
         tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
@@ -365,18 +389,6 @@ def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
     return "TikToken" if tiktoken else "SentencePiece"


-def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
-    use_tiktoken = model.config.use_tiktoken
-    is_tiktoken = tokenizer_args.is_tiktoken
-
-    if use_tiktoken is None:
-        model.config.use_tiktoken = is_tiktoken
-    elif use_tiktoken != is_tiktoken:
-        raise RuntimeError(
-            f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}"
-        )
-
-
 def resolve_model_name(model: str) -> str:
     # If the provided model name is an alias, retrieve the full path.
     if model in model_aliases:

cli.py

Lines changed: 6 additions & 1 deletion
@@ -131,7 +131,12 @@ def add_arguments(parser):
     parser.add_argument(
         "--compile-prefill",
         action="store_true",
-        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
+        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
+    )
+    parser.add_argument(
+        "--parallel-prefill",
+        action="store_true",
+        help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
     )
     parser.add_argument(
         "--profile",

download.py

Lines changed: 7 additions & 4 deletions
@@ -42,10 +42,11 @@ def _download_hf_snapshot(
     else:
         raise e

-
     # Convert the model to the torchchat format.
     print(f"Converting {model_config.name} to torchchat format...")
-    convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)
+    convert_hf_checkpoint(
+        model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
+    )


 def _download_direct(
@@ -79,13 +80,15 @@ def download_and_convert(
         == ModelDistributionChannel.HuggingFaceSnapshot
     ):
         _download_hf_snapshot(model_config, temp_dir, hf_token)
-    elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
+    elif (
+        model_config.distribution_channel == ModelDistributionChannel.DirectDownload
+    ):
         _download_direct(model_config, temp_dir)
     else:
         raise RuntimeError(
             f"Unknown distribution channel {model_config.distribution_channel}."
         )
-
+
     # Move from the temporary directory to the intended location,
     # overwriting if necessary.
     if os.path.isdir(model_dir):

eval.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
     _initialize_tokenizer,
     BuilderArgs,
     TokenizerArgs,
-    validate_args,
 )

 from build.model import Transformer
@@ -245,7 +244,7 @@ def main(args) -> None:
         quantize,
         tokenizer,
     )
-    validate_args(model, tokenizer_args)
+    tokenizer_args.validate_model(model)

     if compile:
         assert not (

export_et_util.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 from build.model import apply_rotary_emb, Attention
-from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
+
+# from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
 from torch import nn


generate.py

Lines changed: 53 additions & 15 deletions
@@ -22,7 +22,6 @@
     _initialize_tokenizer,
     BuilderArgs,
     TokenizerArgs,
-    validate_args,
 )
 from build.model import Transformer
 from build.utils import device_sync, set_precision
@@ -47,6 +46,31 @@ class GeneratorArgs:
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
+    sequential_prefill: bool = True
+
+    def __post_init__(self):
+        if self.compile_prefill and self.sequential_prefill:
+            raise RuntimeError("prefill compilation requires parallel prefill")
+
+    def validate_build(
+        self, builder_args: BuilderArgs, model_description: str = "model"
+    ):
+        reason = ""
+        model_type = ""
+        if not self.sequential_prefill:
+            reason = "parallel prefill"
+        if self.compile_prefill:
+            reason = "model compilation for prefill"
+        if self.compile:
+            reason = "model compilation"
+        if builder_args.dso_path:
+            model_type = "DSO"
+        if builder_args.pte_path:
+            model_type = "PTE"
+        if model_type and reason:
+            raise RuntimeError(
+                f"cannot perform {reason} because a {model_type} {model_description} is used"
+            )

     @classmethod
     def from_args(cls, args): # -> GeneratorArgs:
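In validate_build above, later assignments overwrite earlier ones, so the most demanding requested feature becomes the reported reason (and PTE wins over DSO as the model type); the error fires only when a reason and a model type are both set, since exported DSO/PTE artifacts support neither torch.compile nor parallel prefill. A small illustration with a hypothetical stand-in for BuilderArgs:

from dataclasses import dataclass

@dataclass
class StubBuilderArgs:  # hypothetical stand-in for torchchat's BuilderArgs
    dso_path: str = ""
    pte_path: str = ""

def validate_build(builder_args, *, sequential_prefill=True,
                   compile_prefill=False, compile=False,
                   model_description="model"):
    # Same shape as GeneratorArgs.validate_build above.
    reason, model_type = "", ""
    if not sequential_prefill:
        reason = "parallel prefill"
    if compile_prefill:
        reason = "model compilation for prefill"
    if compile:
        reason = "model compilation"
    if builder_args.dso_path:
        model_type = "DSO"
    if builder_args.pte_path:
        model_type = "PTE"
    if model_type and reason:
        raise RuntimeError(
            f"cannot perform {reason} because a {model_type} {model_description} is used"
        )

validate_build(StubBuilderArgs(), compile=True)  # eager model: fine
try:
    validate_build(StubBuilderArgs(dso_path="model.so"), compile=True)
except RuntimeError as e:
    print(e)  # cannot perform model compilation because a DSO model is used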
@@ -62,6 +86,7 @@ def from_args(cls, args): # -> GeneratorArgs:
             compile=args.compile,
             compile_prefill=args.compile_prefill,
             speculate_k=args.speculate_k,
+            sequential_prefill=not args.parallel_prefill,
         )


@@ -116,7 +141,6 @@ def prefill(
     logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
-    sequential_prefill = True

     if sequential_prefill:
         for i in range(width):
@@ -244,6 +268,7 @@ def generate(
     chat_mode: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
+    sequential_prefill=True,
     callback=lambda x: x,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -276,9 +301,21 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device, dtype=torch.int)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+    next_token = prefill(
+        model,
+        prompt.view(1, -1),
+        input_pos,
+        sequential_prefill=sequential_prefill,
+        **sampling_kwargs,
+    )
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+        prefill(
+            draft_model,
+            prompt.view(1, -1),
+            input_pos,
+            sequential_prefill=sequential_prefill,
+            **sampling_kwargs,
+        )
     seq[T] = next_token

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
@@ -355,11 +392,9 @@ def _main(
     speculative_builder_args: BuilderArgs,
     tokenizer_args: TokenizerArgs,
     generator_args: GeneratorArgs,
-    compile: bool = True,
-    compile_prefill: bool = False,
-    profile: Optional[Path] = None,
-    quantize=None,
-    draft_quantize=None,
+    profile: Optional[Path],
+    quantize,
+    draft_quantize,
 ) -> None:
     """
     Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -398,7 +433,6 @@ def _main(

     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
-    validate_args(model, tokenizer_args)

     # will add a version of _initialize_model in future
     # (need additional args)
@@ -411,6 +445,11 @@ def _main(
     else:
         draft_model = None

+    tokenizer_args.validate_model(model)
+    tokenizer_args.validate_model(draft_model, "draft model")
+    generator_args.validate_build(builder_args)
+    generator_args.validate_build(speculative_builder_args, "draft model")
+
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
@@ -423,7 +462,7 @@ def _main(
             for p in itertools.chain(model.parameters(), model.buffers())
         ]
     )
-    if compile:
+    if generator_args.compile:
         if (
             is_speculative and builder_args.use_tp
         ): # and ("cuda" in builder_args.device):
@@ -443,14 +482,14 @@ def _main(
         )

     # Uncomment to squeeze more perf out of prefill
-    if compile_prefill:
+    if generator_args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     aggregate_metrics = {
         "tokens_per_sec": [],
         "accept_counts": [],
     }
-    start = -1 if compile else 0
+    start = -1 if generator_args.compile else 0

     for i in range(start, generator_args.num_samples):
         device_sync(device=builder_args.device)
@@ -506,6 +545,7 @@ def callback(x):
             callback=callback,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
+            sequential_prefill=generator_args.sequential_prefill,
         )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
         if i == -1:
@@ -560,8 +600,6 @@ def main(args):
         speculative_builder_args,
         tokenizer_args,
         generator_args,
-        args.compile,
-        args.compile_prefill,
         args.profile,
         args.quantize,
         args.draft_quantize,
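For intuition about what the sequential_prefill flag toggles in prefill above: sequential prefill pushes the prompt through the model one token per forward call, while parallel prefill processes the whole prompt in a single call. A toy sketch with a hypothetical stand-in model (the real prefill in generate.py also samples and returns the next token):

import torch

VOCAB = 32000  # hypothetical vocabulary size

def stub_model(x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Stand-in for Transformer.forward: (batch, seq) token ids -> logits.
    return torch.randn(x.size(0), x.size(1), VOCAB)

def toy_prefill(x: torch.Tensor, input_pos: torch.Tensor,
                sequential_prefill: bool = True) -> torch.Tensor:
    width = x.size(1)
    assert input_pos.size(0) == width
    if sequential_prefill:
        # One token per forward call, as in the sequential branch above.
        for i in range(width):
            logits = stub_model(x[:, i : i + 1], input_pos[i].view(-1))
    else:
        # Whole prompt in one forward call: fewer, larger launches.
        logits = stub_model(x, input_pos)
    return logits[:, -1]  # logits at the next-token position

prompt = torch.randint(0, VOCAB, (1, 8))
pos = torch.arange(0, 8, dtype=torch.int)
toy_prefill(prompt, pos, sequential_prefill=True)
toy_prefill(prompt, pos, sequential_prefill=False)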

quantize.py

Lines changed: 4 additions & 3 deletions
@@ -7,9 +7,10 @@
 from __future__ import annotations

 import json
-from functools import reduce
-from math import gcd
-from typing import Dict, Optional, Tuple
+
+# from functools import reduce
+# from math import gcd
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
