
perform parallel prefill when possible #568


Merged 6 commits on Apr 30, 2024
3 changes: 3 additions & 0 deletions build/builder.py
@@ -37,6 +37,7 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
     is_chat_model: bool = False
+    prefill_possible: bool = False

     def __post_init__(self):
         if self.device is None:
@@ -68,6 +69,8 @@ def __post_init__(self):
             print(
                 "Warning: GGUF path ignored because an exported DSO or PTE path specified"
             )
+        if not (self.dso_path) and not (self.pte_path):
+            self.prefill_possible = True

     @classmethod
     def from_args(cls, args): # -> BuilderArgs:
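The new prefill_possible flag records whether the built model can serve parallel prefill: eager or compiled Python models can, while exported DSO/PTE artifacts cannot. As a rough illustration of how a caller might consult such a flag, here is a minimal, self-contained sketch; the BuildConfig class and choose_prefill helper are hypothetical stand-ins, not part of this diff.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class BuildConfig:
        # Mirrors the relevant BuilderArgs fields (illustrative, not the real class).
        dso_path: Optional[str] = None
        pte_path: Optional[str] = None
        prefill_possible: bool = False

        def __post_init__(self):
            # Same gating idea as in the diff: only non-exported models support parallel prefill.
            if not self.dso_path and not self.pte_path:
                self.prefill_possible = True

    def choose_prefill(cfg: BuildConfig, want_parallel: bool) -> str:
        # Fall back to sequential prefill when the exported artifact cannot run in parallel.
        if want_parallel and not cfg.prefill_possible:
            print("Warning: exported DSO/PTE model, falling back to sequential prefill")
            return "sequential"
        return "parallel" if want_parallel else "sequential"

    assert choose_prefill(BuildConfig(), want_parallel=True) == "parallel"
    assert choose_prefill(BuildConfig(dso_path="model.so"), want_parallel=True) == "sequential"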
6 changes: 3 additions & 3 deletions cli.py
@@ -136,12 +136,12 @@ def _add_arguments_common(parser):
     parser.add_argument(
         "--compile-prefill",
         action="store_true",
-        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
+        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
     )
     parser.add_argument(
-        "--parallel-prefill",
+        "--sequential-prefill",
         action="store_true",
-        help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
+        help="Whether to perform prefill sequentially. Only used for model debug.",
     )
     parser.add_argument(
         "--profile",
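The flag rename also flips the default: parallel prefill is now on unless --sequential-prefill is passed, and --compile-prefill no longer depends on a second flag. A small standalone argparse sketch of the new flag semantics (this parser is illustrative only, not the real cli.py parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--compile-prefill",
        action="store_true",
        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
    )
    parser.add_argument(
        "--sequential-prefill",
        action="store_true",
        help="Whether to perform prefill sequentially. Only used for model debug.",
    )

    # With store_true, both flags default to False, so parallel prefill is the default path.
    args = parser.parse_args([])
    assert args.sequential_prefill is False and args.compile_prefill is False

    # Opting into the sequential debug path:
    args = parser.parse_args(["--sequential-prefill"])
    assert args.sequential_prefill is True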
2 changes: 1 addition & 1 deletion export_et_util.py
@@ -10,7 +10,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
         super().__init__()

         dtype = torch.float
-
+        # This is flipped around from what is in build.model's KVCache
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
         self.register_buffer(
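The added comment points at a layout mismatch between the two caches. Assuming build.model's KVCache uses the usual (batch, heads, seq, head_dim) layout, the sketch below shows how it relates to the (batch, seq, heads, head_dim) layout used here; the shapes are toy values, not the real model dimensions.

    import torch

    # Layout used by the cache in this file: (batch, seq, heads, head_dim).
    et_cache = torch.zeros(1, 16, 4, 8, dtype=torch.float)

    # Layout assumed for build.model's KVCache, i.e. the "flipped" one: (batch, heads, seq, head_dim).
    eager_cache = torch.zeros(1, 4, 16, 8, dtype=torch.float)

    # The two layouts are related by transposing the seq and heads dimensions.
    assert et_cache.transpose(1, 2).shape == eager_cache.shape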
22 changes: 15 additions & 7 deletions generate.py
@@ -76,7 +76,7 @@ class GeneratorArgs:
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
-    sequential_prefill: bool = True
+    sequential_prefill: bool = False

     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -104,6 +104,10 @@ def validate_build(

     @classmethod
     def from_args(cls, args):
+        sequential_prefill = (
+            args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
+        )
+
         return cls(
             prompt=args.prompt,
             encoded_prompt=None,
@@ -116,7 +120,7 @@ def from_args(cls, args):
             compile=args.compile,
             compile_prefill=args.compile_prefill,
             speculate_k=args.speculate_k,
-            sequential_prefill=not args.parallel_prefill,
+            sequential_prefill=sequential_prefill,
         )
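With this derivation, prefill runs sequentially either when the user explicitly asks for it or when the model is an exported DSO/PTE artifact; otherwise it runs in parallel. A quick sketch of the resulting behaviour, using SimpleNamespace as a stand-in for the parsed CLI args:

    from types import SimpleNamespace

    def derived_sequential_prefill(args) -> bool:
        # Same expression as in GeneratorArgs.from_args above.
        return args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)

    # Eager/compiled model, no flag: parallel prefill.
    assert not derived_sequential_prefill(
        SimpleNamespace(sequential_prefill=False, dso_path=None, pte_path=None)
    )
    # Exported DSO model: forced to sequential prefill regardless of the flag.
    assert derived_sequential_prefill(
        SimpleNamespace(sequential_prefill=False, dso_path="model.so", pte_path=None)
    )
    # Explicit debug flag: sequential prefill.
    assert derived_sequential_prefill(
        SimpleNamespace(sequential_prefill=True, dso_path=None, pte_path=None)
    )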


@@ -151,10 +155,10 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def sample(
     logits, need_probs: bool, temperature: float = 1.0, top_k: Optional[int] = None
 ):
-    if temperature == 0 and not need_probs:
-        _, idx_next = torch.topk(logits, k=1, dim=-1)
-        idx_next = idx_next.squeeze(dim=(0, 1))
-        return (idx_next, None)
+    # if temperature == 0 and not need_probs:
+    #     _, idx_next = torch.topk(logits, k=1, dim=-1)
+    #     idx_next = idx_next.squeeze(dim=(0, 1))
+    #     return (idx_next, None)
     probs = logits_to_probs(logits[0, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
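With the greedy shortcut commented out, every call now flows through logits_to_probs and a multinomial draw. For readers unfamiliar with that path, here is a self-contained sketch of what a temperature/top-k sampling pair typically looks like; logits_to_probs_sketch mirrors the name used above but is not the repo's exact implementation.

    from typing import Optional

    import torch

    def logits_to_probs_sketch(
        logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
    ) -> torch.Tensor:
        # Temperature scaling, clamped so temperature=0 degenerates to near-greedy
        # instead of dividing by zero.
        logits = logits / max(temperature, 1e-5)
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # Mask out everything below the k-th largest logit.
            logits = torch.where(
                logits < v[..., -1, None], torch.full_like(logits, float("-inf")), logits
            )
        return torch.nn.functional.softmax(logits, dim=-1)

    logits = torch.randn(32000)  # fake vocabulary-sized logits for one position
    probs = logits_to_probs_sketch(logits, temperature=0.8, top_k=200)
    next_token = torch.multinomial(probs, num_samples=1)  # one draw from the filtered distribution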
@@ -175,12 +179,14 @@ def prefill(
     if sequential_prefill:
         for i in range(width):
             x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
         logits = model(x, input_pos)
+        # print(f"logits {logits.shape}")

+    # print(f"x: {x},\n input_pos: {input_pos}\n")
     return sample(logits, need_probs=False, **sampling_kwargs)[0]
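The two branches above differ only in how the prompt is fed to the model: parallel prefill sends the whole prompt and all of its positions in one forward call, while sequential prefill replays it one token at a time, exactly like decoding. The toy model below only echoes the shapes it is called with; it illustrates the calling convention, not the real transformer.

    import torch

    def toy_model(x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real model: returns fake logits of shape (B, S, vocab).
        print(f"model called with x {tuple(x.shape)}, input_pos {tuple(input_pos.shape)}")
        return torch.randn(x.size(0), x.size(1), 32000)

    prompt = torch.arange(6).view(1, -1)  # (B=1, T=6) encoded prompt
    input_pos = torch.arange(6)           # positions 0..T-1

    # Parallel prefill: one forward pass over the whole prompt.
    logits = toy_model(prompt, input_pos)

    # Sequential prefill: one forward pass per prompt token (the debug / DSO / PTE path).
    for i in range(prompt.size(1)):
        logits = toy_model(prompt[:, i].view(-1, 1), input_pos[i].view(-1))

    # Either way, only the logits for the last prompt position seed the first generated token.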


@@ -194,6 +200,7 @@ def decode_one_token(
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
+    # print(f"x: {x},\n input_pos: {input_pos}\n")
     return sample(logits, need_probs=need_probs, **sampling_kwargs)


@@ -379,6 +386,7 @@ def generate(
         sequential_prefill=sequential_prefill,
         **sampling_kwargs,
     )
+    # print(f"sizes: {T} {seq[T].shape} {seq.shape} {next_token.shape}")
     seq[T] = next_token
     callback(next_token.clone().view(-1))

181 changes: 0 additions & 181 deletions parking_lot/quantized_ops.py

This file was deleted.

1 change: 0 additions & 1 deletion qops.py
@@ -305,4 +305,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     @classmethod
     def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
         return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
-
1 change: 0 additions & 1 deletion quantize.py
@@ -23,7 +23,6 @@
     state_dict_device,
     use_et_backend,
 )
-from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

 from qops import (
     LinearInt4 as WeightOnlyInt4Linear,