
Commit 757b224

Add attention_backend to let user choose
1 parent 654bb03 commit 757b224

3 files changed: +24 −1 lines changed


torchchat/cli/builder.py

Lines changed: 2 additions & 0 deletions

@@ -69,6 +69,7 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"
 
     def __post_init__(self):
         if self.device is None:
@@ -202,6 +203,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=args.attention_backend,
         )
 
     @classmethod
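
For orientation, here is a minimal standalone sketch of the pattern above: an argparse flag feeding a dataclass field whose "math" default preserves the previously hard-coded behavior. The names mirror the diff; everything else is illustrative, not torchchat's actual builder code.

# Sketch only: simplified stand-in for the real BuilderArgs/from_args plumbing.
import argparse
from dataclasses import dataclass

@dataclass
class BuilderArgs:
    attention_backend: str = "math"  # "math" matches the old hard-coded SDPBackend.MATH

parser = argparse.ArgumentParser()
parser.add_argument("--attention-backend", type=str, default="math")
args = parser.parse_args(["--attention-backend", "flash_attention"])

# argparse exposes the dashed flag as the attribute args.attention_backend
print(BuilderArgs(attention_backend=args.attention_backend))
# BuilderArgs(attention_backend='flash_attention')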

torchchat/cli/cli.py

Lines changed: 7 additions & 0 deletions

@@ -179,6 +179,13 @@ def _add_model_config_args(parser, verb: str) -> None:
         choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
+    model_config_parser.add_argument(
+        "--attention-backend",
+        type=str,
+        default="math",
+        choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
+        help="SDPBackend to use. Options: MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION",
+    )
 
 
 # Add CLI Args representing output paths of exported model files
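
With the flag registered, a run can select a backend from the command line; for example (invocation shape assumed from torchchat's usual CLI, model name illustrative):

python3 torchchat.py generate llama3.1 --attention-backend flash_attention --prompt "Hello"

The accepted values are the lowercase choices; the help string lists the matching SDPBackend enum names in uppercase, and argparse's choices list rejects anything else at parse time.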

torchchat/generate.py

Lines changed: 15 additions & 1 deletion

@@ -531,15 +531,22 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: str = "math",
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
         encountered_eos = False
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
         for _i in range(
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([sdp_backend_dict[attention_backend]]):
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
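
torch.nn.attention.sdpa_kernel is a context manager that restricts which implementations scaled_dot_product_attention may dispatch to, which is what makes the lookup table above work. A minimal standalone sketch (requires a recent PyTorch; shapes are illustrative, not torchchat code):

# Sketch only: pinning the SDPA backend the way the diff above does.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

sdp_backend_dict = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

# "math" runs everywhere; the fused backends need suitable hardware and
# dtypes, which is why a later hunk falls back to "math" on CPU.
with sdpa_kernel([sdp_backend_dict["math"]]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])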
@@ -683,6 +690,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +807,7 @@
                 if self.is_llama3_model
                 else None
             ),
+            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
@@ -1170,6 +1179,10 @@ def callback(x, *, done_generating=False):
         prof = torch.profiler.profile()
         t0 = time.perf_counter()
         num_tokens_generated = 0
+        if self.builder_args.device == "cpu" and (self.builder_args.attention_backend == "efficient_attention"
+                                                  or self.builder_args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {self.builder_args.attention_backend} is not supported on CPU. Using math instead.")
+            self.builder_args.attention_backend = "math"
         with prof:
             generator_func = self.generate(
                 self.model,
@@ -1186,6 +1199,7 @@ def callback(x, *, done_generating=False):
                 start_pos=start_pos,
                 skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
+                attention_backend=self.builder_args.attention_backend,
             )
             for token_tensor, metrics in generator_func:
                 if token_tensor is not None:
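
Taken together, the flag defaults to "math" at every layer, so behavior is unchanged unless a user opts in. On CPU, requests for efficient_attention or cudnn_attention are downgraded to math with a warning; flash_attention is not guarded, presumably because SDPA ships a CPU flash kernel. Note that the fallback assigns to self.builder_args.attention_backend, so the downgrade persists for later samples in the same session.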
