
Commit 933b520

Revert "add --parallel-prefill options & option validation (#368)"
This reverts commit ab03af2.
1 parent: ab03af2

File tree

7 files changed: +40 −99 lines changed


build/builder.py

Lines changed: 14 additions & 26 deletions
@@ -88,9 +88,7 @@ def from_args(cls, args):  # -> BuilderArgs:
         )
         # The transformers config is keyed on the last section
         # of the name/path.
-        params_table = (
-            model_config.transformer_params_key or model_config.name.split("/")[-1]
-        )
+        params_table = model_config.transformer_params_key or model_config.name.split("/")[-1]

         is_chat_model = False
         if args.is_chat_model:
@@ -145,24 +143,6 @@ class TokenizerArgs:
     is_sentencepiece: bool = True
     is_tiktoken: bool = False

-    def validate_model(
-        self,
-        model: Transformer,
-        model_description: str = "model",
-    ):
-        if model is None:
-            return
-
-        use_tiktoken = model.config.use_tiktoken
-        is_tiktoken = self.is_tiktoken
-
-        if use_tiktoken is None:
-            model.config.use_tiktoken = is_tiktoken
-        elif use_tiktoken != is_tiktoken:
-            raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
-            )
-
     @classmethod
     def from_args(cls, args):  # -> TokenizerArgs:
         is_sentencepiece = True
@@ -172,11 +152,7 @@ def from_args(cls, args):  # -> TokenizerArgs:
             tokenizer_path = args.tokenizer_path
         elif args.model:  # Using a named, well-known model
             model_config = resolve_model_config(args.model)
-            tokenizer_path = (
-                Path(args.model_directory)
-                / model_config.name
-                / model_config.tokenizer_file
-            )
+            tokenizer_path = Path(args.model_directory) / model_config.name / model_config.tokenizer_file

         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
@@ -389,6 +365,18 @@ def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
     return "TikToken" if tiktoken else "SentencePiece"


+def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
+    use_tiktoken = model.config.use_tiktoken
+    is_tiktoken = tokenizer_args.is_tiktoken
+
+    if use_tiktoken is None:
+        model.config.use_tiktoken = is_tiktoken
+    elif use_tiktoken != is_tiktoken:
+        raise RuntimeError(
+            f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}"
+        )
+
+
 def resolve_model_name(model: str) -> str:
     # If the provided model name is an alias, retrieve the full path.
     if model in model_aliases:
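
The net effect in build/builder.py: tokenizer validation moves back from the TokenizerArgs.validate_model method (added by ab03af2) to the module-level validate_args free function, losing the `model is None` guard and the per-model description used for draft models. A minimal, self-contained sketch of the restored check; the _Stub* classes are hypothetical stand-ins for torchchat's real Transformer and TokenizerArgs, and the error message's parentheses are balanced here even though the actual source string leaves them unclosed:

from dataclasses import dataclass
from typing import Optional


@dataclass
class _StubConfig:  # hypothetical stand-in for model.config
    use_tiktoken: Optional[bool] = None


@dataclass
class _StubModel:  # hypothetical stand-in for build.model.Transformer
    config: _StubConfig


@dataclass
class _StubTokenizerArgs:  # hypothetical stand-in for TokenizerArgs
    is_tiktoken: bool = False


def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
    return "TikToken" if tiktoken else "SentencePiece"


def validate_args(model, tokenizer_args):
    # If the checkpoint does not declare a tokenizer, adopt the CLI choice;
    # if it declares one and it disagrees, fail fast instead of decoding garbage.
    use_tiktoken = model.config.use_tiktoken
    is_tiktoken = tokenizer_args.is_tiktoken
    if use_tiktoken is None:
        model.config.use_tiktoken = is_tiktoken
    elif use_tiktoken != is_tiktoken:
        raise RuntimeError(
            f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) "
            f"does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)})"
        )


model = _StubModel(_StubConfig(use_tiktoken=None))
validate_args(model, _StubTokenizerArgs(is_tiktoken=True))
assert model.config.use_tiktoken is True  # an unset config adopts the tokenizer choice

Note the side effect: when the checkpoint leaves use_tiktoken unset, validate_args writes the tokenizer's value back into model.config, so downstream code can rely on a concrete bool.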

cli.py

Lines changed: 1 addition & 6 deletions
@@ -131,12 +131,7 @@ def add_arguments(parser):
     parser.add_argument(
         "--compile-prefill",
         action="store_true",
-        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
-    )
-    parser.add_argument(
-        "--parallel-prefill",
-        action="store_true",
-        help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
+        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
     )
     parser.add_argument(
         "--profile",

download.py

Lines changed: 4 additions & 7 deletions
@@ -42,11 +42,10 @@ def _download_hf_snapshot(
         else:
             raise e

+
     # Convert the model to the torchchat format.
     print(f"Converting {model_config.name} to torchchat format...")
-    convert_hf_checkpoint(
-        model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
-    )
+    convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)


 def _download_direct(
@@ -80,15 +79,13 @@ def download_and_convert(
         == ModelDistributionChannel.HuggingFaceSnapshot
     ):
         _download_hf_snapshot(model_config, temp_dir, hf_token)
-    elif (
-        model_config.distribution_channel == ModelDistributionChannel.DirectDownload
-    ):
+    elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
         _download_direct(model_config, temp_dir)
     else:
         raise RuntimeError(
             f"Unknown distribution channel {model_config.distribution_channel}."
         )
-
+
     # Move from the temporary directory to the intended location,
     # overwriting if necessary.
     if os.path.isdir(model_dir):
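
The download.py changes are formatting-only reverts, but the surrounding context shows the flow they sit in: download into a temporary directory, convert, then replace the target directory so a failed download never leaves a half-written model behind. A generic sketch of that pattern under assumed names; fetch and convert are hypothetical placeholders for _download_hf_snapshot/_download_direct and convert_hf_checkpoint:

import os
import shutil
import tempfile


def download_and_convert(model_dir: str, fetch, convert) -> None:
    # Stage everything in a temp dir so model_dir is only ever replaced
    # by a fully downloaded and converted artifact.
    with tempfile.TemporaryDirectory() as temp_dir:
        fetch(temp_dir)    # hypothetical: _download_hf_snapshot / _download_direct
        convert(temp_dir)  # hypothetical: convert_hf_checkpoint

        # Move from the temporary directory to the intended location,
        # overwriting if necessary (mirrors the comment in download.py).
        if os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
        shutil.copytree(temp_dir, model_dir)


download_and_convert("/tmp/demo-model", fetch=lambda d: None, convert=lambda d: None)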

eval.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
     _initialize_tokenizer,
     BuilderArgs,
     TokenizerArgs,
+    validate_args,
 )

 from build.model import Transformer
@@ -244,7 +245,7 @@ def main(args) -> None:
         quantize,
         tokenizer,
     )
-    tokenizer_args.validate_model(model)
+    validate_args(model, tokenizer_args)

     if compile:
         assert not (

export_et_util.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 import torch
 from build.model import apply_rotary_emb, Attention
-
-# from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
+from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
 from torch import nn

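The revert restores the hard import of sdpa_with_kv_cache; ab03af2 had commented it out, presumably so the module could be imported without ExecuTorch's custom ops installed. If that soft-dependency behavior were actually wanted, a guarded import is the usual pattern; this is a sketch of the alternative, not what the file does after the revert:

# Guarded-import sketch: degrade gracefully when ExecuTorch's custom ops
# are not installed, instead of failing at module import time.
try:
    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
    HAS_CUSTOM_SDPA = True
except ImportError:
    sdpa_with_kv_cache = None
    HAS_CUSTOM_SDPA = False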
generate.py

Lines changed: 15 additions & 53 deletions
@@ -22,6 +22,7 @@
     _initialize_tokenizer,
     BuilderArgs,
     TokenizerArgs,
+    validate_args,
 )
 from build.model import Transformer
 from build.utils import device_sync, set_precision
@@ -46,31 +47,6 @@ class GeneratorArgs:
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
-    sequential_prefill: bool = True
-
-    def __post_init__(self):
-        if self.compile_prefill and self.sequential_prefill:
-            raise RuntimeError("prefill compilation requires parallel prefill")
-
-    def validate_build(
-        self, builder_args: BuilderArgs, model_description: str = "model"
-    ):
-        reason = ""
-        model_type = ""
-        if not self.sequential_prefill:
-            reason = "parallel prefill"
-        if self.compile_prefill:
-            reason = "model compilation for prefill"
-        if self.compile:
-            reason = "model compilation"
-        if builder_args.dso_path:
-            model_type = "DSO"
-        if builder_args.pte_path:
-            model_type = "PTE"
-        if model_type and reason:
-            raise RuntimeError(
-                f"cannot perform {reason} because a {model_type} {model_description} is used"
-            )

     @classmethod
     def from_args(cls, args):  # -> GeneratorArgs:
@@ -86,7 +62,6 @@ def from_args(cls, args):  # -> GeneratorArgs:
             compile=args.compile,
             compile_prefill=args.compile_prefill,
             speculate_k=args.speculate_k,
-            sequential_prefill=not args.parallel_prefill,
         )


@@ -141,6 +116,7 @@ def prefill(
     logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
+    sequential_prefill = True

     if sequential_prefill:
         for i in range(width):
@@ -268,7 +244,6 @@ def generate(
     chat_mode: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
-    sequential_prefill=True,
     callback=lambda x: x,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -301,21 +276,9 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device, dtype=torch.int)

-    next_token = prefill(
-        model,
-        prompt.view(1, -1),
-        input_pos,
-        sequential_prefill=sequential_prefill,
-        **sampling_kwargs,
-    )
+    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
     if is_speculative:
-        prefill(
-            draft_model,
-            prompt.view(1, -1),
-            input_pos,
-            sequential_prefill=sequential_prefill,
-            **sampling_kwargs,
-        )
+        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
     seq[T] = next_token

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
@@ -392,9 +355,11 @@ def _main(
     speculative_builder_args: BuilderArgs,
     tokenizer_args: TokenizerArgs,
     generator_args: GeneratorArgs,
-    profile: Optional[Path],
-    quantize,
-    draft_quantize,
+    compile: bool = True,
+    compile_prefill: bool = False,
+    profile: Optional[Path] = None,
+    quantize=None,
+    draft_quantize=None,
 ) -> None:
     """
     Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -433,6 +398,7 @@ def _main(

     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
+    validate_args(model, tokenizer_args)

     # will add a version of _initialize_model in future
     # (need additional args)
@@ -445,11 +411,6 @@ def _main(
     else:
         draft_model = None

-    tokenizer_args.validate_model(model)
-    tokenizer_args.validate_model(draft_model, "draft model")
-    generator_args.validate_build(builder_args)
-    generator_args.validate_build(speculative_builder_args, "draft model")
-
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
@@ -462,7 +423,7 @@ def _main(
             for p in itertools.chain(model.parameters(), model.buffers())
         ]
     )
-    if generator_args.compile:
+    if compile:
         if (
             is_speculative and builder_args.use_tp
         ):  # and ("cuda" in builder_args.device):
@@ -482,14 +443,14 @@ def _main(
         )

     # Uncomment to squeeze more perf out of prefill
-    if generator_args.compile_prefill:
+    if compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     aggregate_metrics = {
         "tokens_per_sec": [],
         "accept_counts": [],
     }
-    start = -1 if generator_args.compile else 0
+    start = -1 if compile else 0

     for i in range(start, generator_args.num_samples):
         device_sync(device=builder_args.device)
@@ -545,7 +506,6 @@ def callback(x):
             callback=callback,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
-            sequential_prefill=generator_args.sequential_prefill,
         )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
         if i == -1:
@@ -600,6 +560,8 @@ def main(args):
         speculative_builder_args,
         tokenizer_args,
         generator_args,
+        args.compile,
+        args.compile_prefill,
         args.profile,
         args.quantize,
         args.draft_quantize,
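
The behavioral core of the revert is in prefill: with sequential_prefill = True hardcoded, the prompt is always processed one token per forward pass, and the --parallel-prefill path (whole prompt in a single forward pass) disappears along with the validate_build checks that kept it away from DSO/PTE models. A toy illustration of the two strategies; toy_forward is a hypothetical stand-in for the real model(x, input_pos) call:

import torch


def toy_forward(x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for model(x, input_pos): per-position logits.
    vocab_size = 16
    return torch.randn(x.size(0), x.size(1), vocab_size)


def prefill(prompt: torch.Tensor, sequential: bool) -> torch.Tensor:
    width = prompt.size(1)
    if sequential:
        # Sequential prefill: one forward pass per prompt token, as the
        # reverted code now hardcodes. Slower, but the only mode that
        # DSO/PTE exported models supported at the time.
        for i in range(width):
            logits = toy_forward(prompt[:, i : i + 1], torch.tensor([i]))
    else:
        # Parallel prefill (what --parallel-prefill enabled): the whole
        # prompt goes through in a single forward pass.
        logits = toy_forward(prompt, torch.arange(width))
    return logits[:, -1].argmax(dim=-1)


prompt = torch.randint(0, 16, (1, 8))
print(prefill(prompt, sequential=True), prefill(prompt, sequential=False))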

quantize.py

Lines changed: 3 additions & 4 deletions
@@ -7,10 +7,9 @@
 from __future__ import annotations

 import json
-
-# from functools import reduce
-# from math import gcd
-from typing import Dict, Optional
+from functools import reduce
+from math import gcd
+from typing import Dict, Optional, Tuple

 import torch
 import torch.nn as nn
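
quantize.py gets back the reduce/gcd/Tuple imports that ab03af2 had commented out or trimmed. The diff does not show where they are used, but a typical job for reduce(gcd, ...) in quantization code is choosing a group size that divides every weight dimension; a hypothetical sketch, not necessarily how quantize.py uses them:

from functools import reduce
from math import gcd
from typing import Tuple


def largest_common_groupsize(dims: Tuple[int, ...], preferred: int = 256) -> int:
    # A quantization group size must evenly divide every dimension it tiles;
    # the gcd of the dims is the largest size that satisfies all of them.
    common = reduce(gcd, dims)
    return min(common, preferred)


print(largest_common_groupsize((4096, 11008, 4096)))  # gcd is 256 -> 256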
