Revert "add --parallel-prefill options & option validation" #374

Closed
wants to merge 1 commit into from
40 changes: 14 additions & 26 deletions build/builder.py
@@ -88,9 +88,7 @@ def from_args(cls, args): # -> BuilderArgs:
)
# The transformers config is keyed on the last section
# of the name/path.
params_table = (
model_config.transformer_params_key or model_config.name.split("/")[-1]
)
params_table = model_config.transformer_params_key or model_config.name.split("/")[-1]

is_chat_model = False
if args.is_chat_model:
@@ -145,24 +143,6 @@ class TokenizerArgs:
is_sentencepiece: bool = True
is_tiktoken: bool = False

def validate_model(
self,
model: Transformer,
model_description: str = "model",
):
if model is None:
return

use_tiktoken = model.config.use_tiktoken
is_tiktoken = self.is_tiktoken

if use_tiktoken is None:
model.config.use_tiktoken = is_tiktoken
elif use_tiktoken != is_tiktoken:
raise RuntimeError(
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
)

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
is_sentencepiece = True
@@ -172,11 +152,7 @@ def from_args(cls, args): # -> TokenizerArgs:
tokenizer_path = args.tokenizer_path
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = (
Path(args.model_directory)
/ model_config.name
/ model_config.tokenizer_file
)
tokenizer_path = Path(args.model_directory) / model_config.name / model_config.tokenizer_file

elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
@@ -389,6 +365,18 @@ def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
return "TikToken" if tiktoken else "SentencePiece"


def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
use_tiktoken = model.config.use_tiktoken
is_tiktoken = tokenizer_args.is_tiktoken

if use_tiktoken is None:
model.config.use_tiktoken = is_tiktoken
elif use_tiktoken != is_tiktoken:
raise RuntimeError(
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}"
)


def resolve_model_name(model: str) -> str:
# If the provided model name is an alias, retrieve the full path.
if model in model_aliases:
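
The net effect of the builder.py change is that tokenizer/model consistency checking now lives in the module-level validate_args helper rather than in a TokenizerArgs.validate_model method. A minimal usage sketch, assuming a parsed args namespace plus builder_args, tokenizer, and quantize objects from the surrounding setup (names follow the diff; the call sequence itself is illustrative, not taken from this PR):

from build.builder import TokenizerArgs, _initialize_model, validate_args

tokenizer_args = TokenizerArgs.from_args(args)      # args: parsed CLI namespace (assumed)
model = _initialize_model(builder_args, quantize, tokenizer)

# Fills in model.config.use_tiktoken when it is unset; raises RuntimeError when
# the model expects one tokenizer family (TikToken vs. SentencePiece) and the
# provided tokenizer is the other.
validate_args(model, tokenizer_args)
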
7 changes: 1 addition & 6 deletions cli.py
@@ -131,12 +131,7 @@ def add_arguments(parser):
parser.add_argument(
"--compile-prefill",
action="store_true",
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
)
parser.add_argument(
"--parallel-prefill",
action="store_true",
help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
)
parser.add_argument(
"--profile",
11 changes: 4 additions & 7 deletions download.py
@@ -42,11 +42,10 @@ def _download_hf_snapshot(
else:
raise e


# Convert the model to the torchchat format.
print(f"Converting {model_config.name} to torchchat format...")
convert_hf_checkpoint(
model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
)
convert_hf_checkpoint(model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True)


def _download_direct(
@@ -80,15 +79,13 @@ def download_and_convert(
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, temp_dir, hf_token)
elif (
model_config.distribution_channel == ModelDistributionChannel.DirectDownload
):
elif model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, temp_dir)
else:
raise RuntimeError(
f"Unknown distribution channel {model_config.distribution_channel}."
)

# Move from the temporary directory to the intended location,
# overwriting if necessary.
if os.path.isdir(model_dir):
3 changes: 2 additions & 1 deletion eval.py
@@ -16,6 +16,7 @@
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
validate_args,
)

from build.model import Transformer
@@ -244,7 +245,7 @@ def main(args) -> None:
quantize,
tokenizer,
)
tokenizer_args.validate_model(model)
validate_args(model, tokenizer_args)

if compile:
assert not (
3 changes: 1 addition & 2 deletions export_et_util.py
@@ -1,7 +1,6 @@
import torch
from build.model import apply_rotary_emb, Attention

# from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache
from torch import nn


68 changes: 15 additions & 53 deletions generate.py
@@ -22,6 +22,7 @@
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
validate_args,
)
from build.model import Transformer
from build.utils import device_sync, set_precision
@@ -46,31 +47,6 @@ class GeneratorArgs:
compile: bool = False
compile_prefill: bool = False
speculate_k: int = 5
sequential_prefill: bool = True

def __post_init__(self):
if self.compile_prefill and self.sequential_prefill:
raise RuntimeError("prefill compilation requires parallel prefill")

def validate_build(
self, builder_args: BuilderArgs, model_description: str = "model"
):
reason = ""
model_type = ""
if not self.sequential_prefill:
reason = "parallel prefill"
if self.compile_prefill:
reason = "model compilation for prefill"
if self.compile:
reason = "model compilation"
if builder_args.dso_path:
model_type = "DSO"
if builder_args.pte_path:
model_type = "PTE"
if model_type and reason:
raise RuntimeError(
f"cannot perform {reason} because a {model_type} {model_description} is used"
)

@classmethod
def from_args(cls, args): # -> GeneratorArgs:
@@ -86,7 +62,6 @@ def from_args(cls, args): # -> GeneratorArgs:
compile=args.compile,
compile_prefill=args.compile_prefill,
speculate_k=args.speculate_k,
sequential_prefill=not args.parallel_prefill,
)


@@ -141,6 +116,7 @@ def prefill(
logging.debug(f"x: {x}, input_pos: {input_pos}")
width = x.size(1)
assert input_pos.size(0) == width
sequential_prefill = True

if sequential_prefill:
for i in range(width):
@@ -268,7 +244,6 @@
chat_mode: bool,
draft_model: Transformer,
speculate_k: Optional[int] = 8,
sequential_prefill=True,
callback=lambda x: x,
**sampling_kwargs,
) -> torch.Tensor:
@@ -301,21 +276,9 @@
seq = empty
input_pos = torch.arange(0, T, device=device, dtype=torch.int)

next_token = prefill(
model,
prompt.view(1, -1),
input_pos,
sequential_prefill=sequential_prefill,
**sampling_kwargs,
)
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
if is_speculative:
prefill(
draft_model,
prompt.view(1, -1),
input_pos,
sequential_prefill=sequential_prefill,
**sampling_kwargs,
)
prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
seq[T] = next_token

input_pos = torch.tensor([T], device=device, dtype=torch.int)
@@ -392,9 +355,11 @@ def _main(
speculative_builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
generator_args: GeneratorArgs,
profile: Optional[Path],
quantize,
draft_quantize,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
quantize=None,
draft_quantize=None,
) -> None:
"""
Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -433,6 +398,7 @@

builder_args.setup_caches = False
model = _initialize_model(builder_args, quantize, tokenizer)
validate_args(model, tokenizer_args)

# will add a version of _initialize_model in future
# (need additional args)
@@ -445,11 +411,6 @@
else:
draft_model = None

tokenizer_args.validate_model(model)
tokenizer_args.validate_model(draft_model, "draft model")
generator_args.validate_build(builder_args)
generator_args.validate_build(speculative_builder_args, "draft model")

encoded = encode_tokens(
tokenizer, generator_args.prompt, bos=True, device=builder_args.device
)
@@ -462,7 +423,7 @@
for p in itertools.chain(model.parameters(), model.buffers())
]
)
if generator_args.compile:
if compile:
if (
is_speculative and builder_args.use_tp
): # and ("cuda" in builder_args.device):
@@ -482,14 +443,14 @@
)

# Uncomment to squeeze more perf out of prefill
if generator_args.compile_prefill:
if compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

aggregate_metrics = {
"tokens_per_sec": [],
"accept_counts": [],
}
start = -1 if generator_args.compile else 0
start = -1 if compile else 0

for i in range(start, generator_args.num_samples):
device_sync(device=builder_args.device)
@@ -545,7 +506,6 @@ def callback(x):
callback=callback,
temperature=generator_args.temperature,
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
)
aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
if i == -1:
@@ -600,6 +560,8 @@ def main(args):
speculative_builder_args,
tokenizer_args,
generator_args,
args.compile,
args.compile_prefill,
args.profile,
args.quantize,
args.draft_quantize,
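
Because generate.py now hard-codes sequential_prefill = True inside prefill() and drops the sequential_prefill plumbing from GeneratorArgs and generate(), prefill always runs one prompt token per forward pass. A self-contained sketch of the distinction the removed flag used to control (TinyModel is a hypothetical stand-in for build.model.Transformer and ignores input_pos instead of indexing a KV cache):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for build.model.Transformer, only to make the
    # sketch runnable; a real model would use input_pos to update its KV cache.
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens, input_pos):
        return self.head(self.emb(tokens))

model = TinyModel()
prompt = torch.randint(0, 32, (1, 8))            # [batch=1, prompt length=8]
input_pos = torch.arange(0, 8, dtype=torch.int)

# Sequential prefill (what the revert keeps): one prompt token per forward call.
for i in range(prompt.size(1)):
    logits = model(prompt[:, i].view(1, -1), input_pos[i].view(-1))

# Parallel prefill (what the revert removes): the whole prompt in one call.
logits = model(prompt, input_pos)
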
7 changes: 3 additions & 4 deletions quantize.py
@@ -7,10 +7,9 @@
from __future__ import annotations

import json

# from functools import reduce
# from math import gcd
from typing import Dict, Optional
from functools import reduce
from math import gcd
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn