Commit 9d10748

mikekgfb authored and malfet committed
perform parallel prefill when possible (#568)
* perform parallel prefill when possible
* typo
* disable hack
* remove print
* remove debug messages which prevent export
* fixes
1 parent 8580712 commit 9d10748

7 files changed: +22 -194 lines changed

build/builder.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@ class BuilderArgs:
     setup_caches: bool = False
     use_tp: bool = False
     is_chat_model: bool = False
+    prefill_possible: bool = False

     def __post_init__(self):
         if self.device is None:

@@ -68,6 +69,8 @@ def __post_init__(self):
             print(
                 "Warning: GGUF path ignored because an exported DSO or PTE path specified"
             )
+        if not (self.dso_path) and not (self.pte_path):
+            self.prefill_possible = True

     @classmethod
     def from_args(cls, args): # -> BuilderArgs:
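
The new field reduces to a single predicate: parallel prefill is possible only when the model runs as eager or compiled PyTorch rather than from an exported artifact. A minimal standalone sketch of that gate, mirroring the dso_path/pte_path fields that BuilderArgs carries elsewhere:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class BuilderArgsSketch:
        dso_path: Optional[str] = None   # path to an AOT Inductor .so export, if any
        pte_path: Optional[str] = None   # path to an ExecuTorch .pte export, if any
        prefill_possible: bool = False

        def __post_init__(self):
            # Parallel prefill needs a model callable with a full prompt;
            # DSO/PTE exports are limited to one token at a time.
            if not self.dso_path and not self.pte_path:
                self.prefill_possible = True

    assert BuilderArgsSketch().prefill_possible
    assert not BuilderArgsSketch(pte_path="model.pte").prefill_possible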

cli.py

Lines changed: 3 additions & 3 deletions
@@ -136,12 +136,12 @@ def _add_arguments_common(parser):
     parser.add_argument(
         "--compile-prefill",
         action="store_true",
-        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
+        help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
     )
     parser.add_argument(
-        "--parallel-prefill",
+        "--sequential-prefill",
         action="store_true",
-        help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
+        help="Whether to perform prefill sequentially. Only used for model debug.",
     )
     parser.add_argument(
         "--profile",

export_et_util.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
         super().__init__()

         dtype = torch.float
-
+
         # This is flipped around from what is in build.model's KVCache
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
         self.register_buffer(
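
For context on the "flipped around" comment: this cache stores entries as (batch, seq, heads, head_dim), whereas build.model's KVCache presumably uses the gpt-fast-style (batch, heads, seq, head_dim) layout; that eager-side layout is an assumption here, not shown in this diff. A minimal sketch of the buffer registration:

    import torch
    import torch.nn as nn

    class ETKVCacheSketch(nn.Module):
        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
            super().__init__()
            dtype = torch.float  # hardcoded, as in the diff above
            # Flipped relative to the assumed eager-side layout of
            # (max_batch_size, n_heads, max_seq_length, head_dim).
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
            self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))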

generate.py

Lines changed: 15 additions & 7 deletions
@@ -76,7 +76,7 @@ class GeneratorArgs:
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
-    sequential_prefill: bool = True
+    sequential_prefill: bool = False

     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
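
Because the default is now False, the guard in __post_init__ shown above only triggers when a user explicitly combines --compile-prefill with --sequential-prefill. A minimal sketch of that incompatibility check; the exact error text is an assumption, not taken from this commit:

    def validate_prefill_flags(compile_prefill: bool, sequential_prefill: bool) -> None:
        # Compiling prefill assumes one multi-token forward pass over the
        # prompt; sequential prefill never issues such a pass.
        if compile_prefill and sequential_prefill:
            raise RuntimeError("--compile-prefill requires parallel prefill")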
@@ -104,6 +104,10 @@ def validate_build(

     @classmethod
     def from_args(cls, args):
+        sequential_prefill = (
+            args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
+        )
+
         return cls(
             prompt=args.prompt,
             encoded_prompt=None,

@@ -116,7 +120,7 @@ def from_args(cls, args):
             compile=args.compile,
             compile_prefill=args.compile_prefill,
             speculate_k=args.speculate_k,
-            sequential_prefill=not args.parallel_prefill,
+            sequential_prefill=sequential_prefill,
         )

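The derivation above folds the export check into one expression: an exported DSO or PTE model always forces sequential prefill, regardless of the flag. A self-contained sketch (SimpleNamespace stands in for the parsed argparse namespace):

    from types import SimpleNamespace

    def derive_sequential_prefill(args) -> bool:
        # Exported models cannot run a multi-token prefill pass, so the
        # explicit flag is OR-ed with the presence of either export path.
        return args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)

    args = SimpleNamespace(sequential_prefill=False, dso_path=None, pte_path="model.pte")
    assert derive_sequential_prefill(args)       # forced by the PTE export
    args.pte_path = None
    assert not derive_sequential_prefill(args)   # eager model: parallel prefill
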
@@ -151,10 +155,10 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None
 def sample(
     logits, need_probs: bool, temperature: float = 1.0, top_k: Optional[int] = None
 ):
-    if temperature == 0 and not need_probs:
-        _, idx_next = torch.topk(logits, k=1, dim=-1)
-        idx_next = idx_next.squeeze(dim=(0, 1))
-        return (idx_next, None)
+    # if temperature == 0 and not need_probs:
+    #     _, idx_next = torch.topk(logits, k=1, dim=-1)
+    #     idx_next = idx_next.squeeze(dim=(0, 1))
+    #     return (idx_next, None)
     probs = logits_to_probs(logits[0, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
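
With the greedy shortcut disabled, every call now reaches multinomial_sample_one_no_sync. For reference, a common CUDA-sync-free sampler of the kind that name suggests; this mirrors the widely used gpt-fast helper and is an assumption about this repo's definition:

    import torch

    def multinomial_sample_one_no_sync(probs: torch.Tensor) -> torch.Tensor:
        # Gumbel-max-style trick: dividing by i.i.d. Exponential(1) noise and
        # taking argmax samples from `probs` without a device-host sync.
        q = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)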
@@ -175,12 +179,14 @@ def prefill(
     if sequential_prefill:
         for i in range(width):
             x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
         logits = model(x, input_pos)
+        # print(f"logits {logits.shape}")

+    # print(f"x: {x},\n input_pos: {input_pos}\n")
     return sample(logits, need_probs=False, **sampling_kwargs)[0]

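The two branches above differ only in how the prompt reaches the model: sequential prefill issues one forward pass per prompt token, while parallel prefill consumes the whole prompt in a single pass. A simplified sketch, assuming model(x, input_pos) accepts a [B, S] token tensor with matching positions:

    import torch

    def prefill_sketch(model, x: torch.Tensor, input_pos: torch.Tensor,
                       sequential_prefill: bool) -> torch.Tensor:
        # x: [B, S] prompt tokens; input_pos: [S] positions of those tokens.
        if sequential_prefill:
            # Debug path: S separate single-token forward passes.
            for i in range(x.size(1)):
                logits = model(x[:, i].view(-1, 1), input_pos[i].view(-1))
        else:
            # Fast path: one forward pass over the entire prompt.
            logits = model(x, input_pos)
        return logits  # logits at the last processed position

Either way, prefill's job is to populate the KV cache; only the final position's logits matter, which is why the function above feeds them straight to sample(...) and returns the first generated token.
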
@@ -194,6 +200,7 @@ def decode_one_token(
     # input_pos: [B, 1]
     assert input_pos.shape[-1] == 1
     logits = model(x, input_pos)
+    # print(f"x: {x},\n input_pos: {input_pos}\n")
     return sample(logits, need_probs=need_probs, **sampling_kwargs)

@@ -379,6 +386,7 @@ def generate(
         sequential_prefill=sequential_prefill,
         **sampling_kwargs,
     )
+    # print(f"sizes: {T} {seq[T].shape} {seq.shape} {next_token.shape}")
     seq[T] = next_token
     callback(next_token.clone().view(-1))

parking_lot/quantized_ops.py

Lines changed: 0 additions & 181 deletions
This file was deleted.

qops.py

Lines changed: 0 additions & 1 deletion
@@ -305,4 +305,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
     @classmethod
     def _check_k(cls, *, k, groupsize=1, inner_k_tiles=1):
         return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
-

quantize.py

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@
     state_dict_device,
     use_et_backend,
 )
-from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

 from qops import (
     LinearInt4 as WeightOnlyInt4Linear,
