
Commit 7d0d06d

Replace scaled_dot_product_attention lowering pass with decomposition (#3296)
1 parent bed5d37 commit 7d0d06d

File tree

8 files changed: +477 -622 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 32 deletions
@@ -2750,38 +2750,6 @@ def aten_ops_max_pool(
     )


-def attention_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    # Currently, `attn_mask` is not supported
-    return args_bounds_check(node.args, 3) is None
-
-
-@dynamo_tensorrt_converter(
-    torch.nn.functional.scaled_dot_product_attention,
-    capability_validator=attention_validator,
-    supports_dynamic_shapes=True,
-)
-def tensorrt_scaled_dot_product_attention(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.attention.scaled_dot_product_attention(
-        ctx,
-        target,
-        SourceIR.TORCHTRT_LOWERED,
-        name,
-        args[0],
-        args[1],
-        args[2],
-        args_bounds_check(args, 5, False),
-        kwargs.get("scale", None),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
     activation,
     addmm,
     arange,
-    attention,
     cast,
     cat,
     condition,

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 0 additions & 165 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 127 additions & 1 deletion
@@ -1,6 +1,6 @@
 import logging
 from enum import Enum, auto
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch._decomp import register_decomposition
@@ -423,6 +423,132 @@ def instance_norm_decomposition(
     )


+@register_torch_trt_decomposition(
+    aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+) -> torch.Tensor:
+    L, S = query.size(-2), key.size(-2)
+    device = query.device
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal:
+        assert attn_mask is None, "attn_mask must be None when is_causal=True"
+        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias = attn_mask + attn_bias
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1)
+
+    if scale is None:
+        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
+        attn_weight = attn_weight / scale
+    else:
+        attn_weight = attn_weight * scale
+
+    attn_weight = attn_weight + attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    return attn_weight @ value
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_flash_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, None, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_efficient_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None
+
+
+@register_torch_trt_decomposition(
+    aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def scaled_dot_product_cudnn_attention_decomposition(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor],
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = False,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.SymInt,
+    torch.SymInt,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    attn = scaled_dot_product_attention_decomposition(
+        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
+    )
+    return attn, None, None, None, 0, 0, None, None, None
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
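
For reference, scaled_dot_product_attention_decomposition above reproduces the eager-mode math of torch.nn.functional.scaled_dot_product_attention, i.e. softmax(Q @ K^T * scale + bias) @ V. The following is a minimal sketch, not part of this commit, that checks the two paths agree numerically; it assumes the decomposition function can be imported directly from the module shown in the diff.

import torch
import torch.nn.functional as F

# Illustrative only: the import path assumes the module layout shown above.
from torch_tensorrt.dynamo.lowering._decompositions import (
    scaled_dot_product_attention_decomposition,
)

# Arbitrary example shapes: (batch, heads, seq_len, head_dim)
query = torch.randn(2, 8, 16, 64)
key = torch.randn(2, 8, 16, 64)
value = torch.randn(2, 8, 16, 64)

ref = F.scaled_dot_product_attention(query, key, value, is_causal=True)
dec = scaled_dot_product_attention_decomposition(query, key, value, is_causal=True)

# dropout_p defaults to 0.0, so both paths should match to float32 tolerance.
assert torch.allclose(ref, dec, atol=1e-4)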

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
@@ -8,7 +8,6 @@
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_linear import lower_linear
-from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_assert_scalar import remove_assert_scalar
 from .remove_detach import remove_detach
@@ -23,7 +22,6 @@
     remove_input_alias_fixing_clones,
     constant_fold,
     repair_input_as_output,
-    lower_scaled_dot_product_attention,
     lower_linear,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
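
With the dedicated converter, lowering pass, and TensorRT attention implementation removed, graphs that call scaled_dot_product_attention are handled by the decompositions registered above during Dynamo lowering. A hedged usage sketch, not part of this commit, with an illustrative module and input shapes:

import torch
import torch_tensorrt

class AttnBlock(torch.nn.Module):
    def forward(self, q, k, v):
        # Decomposed into matmul/softmax/masking ops via the entries added to
        # TORCH_TRT_DECOMPOSITIONS, rather than hitting a dedicated converter.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

model = AttnBlock().eval().cuda()
inputs = [torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3)]

# Standard dynamo-path compilation; no SDPA-specific options are required.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)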
