
Commit 68c2d45

chore: revert attention decomposition due to flux bug (#3332)
1 parent 8eff5a6

File tree: 8 files changed (+623, -478 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 32 additions & 0 deletions
@@ -2730,6 +2730,38 @@ def aten_ops_max_pool(
     )
 
 
+def attention_validator(
+    node: Node, settings: Optional[CompilationSettings] = None
+) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
+@dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
+    supports_dynamic_shapes=True,
+)
+def tensorrt_scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.attention.scaled_dot_product_attention(
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args_bounds_check(args, 5, False),
+        kwargs.get("scale", None),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
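
For context, a minimal usage sketch (not part of this commit): a module whose SDPA call passes no attn_mask satisfies attention_validator and should be handled by the converter above, while a call with an explicit mask falls back to PyTorch execution. The compile invocation assumes the documented dynamo frontend; exact options may differ between releases.

# Illustrative sketch only; model and tensor shapes are made up for the example.
import torch
import torch.nn.functional as F
import torch_tensorrt


class CausalAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # attn_mask (positional arg 3) is None, so attention_validator accepts the node
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


model = CausalAttention().cuda().eval()
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

# Assumes the dynamo IR path; compile options may vary across torch_tensorrt versions.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[q, k, v])
print(trt_model(q, k, v).shape)  # torch.Size([2, 8, 128, 64])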

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
     activation,
     addmm,
     arange,
+    attention,
    cast,
    cat,
    condition,

py/torch_tensorrt/dynamo/conversion/impl/attention.py (new file)

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+import math
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # the lower triangle keeps the elements whose row index is >= the column index
+    row = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", input, 0)
+    col = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", input, 1)
+    rc = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", row, col)
+    arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange", start=0, end=rc, step=1
+    )
+    # recover the row indices
+    row_tensor = impl.elementwise.trunc_div(
+        ctx, target, source_ir, name + "_trunc_div_col", arange_tensor, col
+    )
+    # recover the column indices
+    col_tensor = impl.elementwise.fmod(
+        ctx, target, source_ir, name + "_trunc_div_row", arange_tensor, col
+    )
+    cond = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_tensor, col_tensor
+    )
+    return impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape", cond, [row, col]
+    )
+
+
+def scaled_dot_product_attention(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    query: TRTTensor,
+    key: TRTTensor,
+    value: TRTTensor,
+    is_causal: bool,
+    scale: Optional[float],
+) -> TRTTensor:
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            sqrt_scaled,
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
+
+    if is_causal:
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # L and/or S has a dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, -2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, -2)
+
+            LS = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", L, S)
+
+            # generate an int32 tensor of shape (L, S)
+            arange_tensor = impl.arange.arange(
+                ctx, target, source_ir, name=name + "_arange", start=0, end=LS, step=1
+            )
+            shape_tensor = impl.shuffle.reshape(
+                ctx, target, source_ir, name + "_reshape", arange_tensor, [L, S]
+            )
+
+            # attn_bias should be float32, so cast the index tensor
+            shape_tensor = cast_trt_tensor(
+                ctx, shape_tensor, trt.float32, name + "_casted", target, source_ir
+            )
+
+            # initialize attn_bias as a zero tensor
+            attn_bias = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_zero", shape_tensor, 0.0
+            )
+
+            # generate the mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", shape_tensor)
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            inf_tensor = impl.elementwise.mul(
+                ctx, target, source_ir, name + "_mul_-inf", shape_tensor, float("-inf")
+            )
+            cond = impl.elementwise.eq(
+                ctx, target, source_ir, name + "_cond_true", temp_mask, bool(True)
+            )
+            # fill the masked (upper-triangular) positions of attn_bias with -inf
+            attn_bias = impl.condition.select(
+                ctx, target, source_ir, name + "_select", inf_tensor, attn_bias, cond
+            )
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", scaled, -1, False
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
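
The dynamic-shape branch builds the causal mask from a flattened arange rather than a host-side tril. A small NumPy sketch of that index arithmetic (illustrative only, shapes chosen arbitrarily):

import numpy as np

L, S = 4, 6                              # query length, key length (example values)
flat = np.arange(L * S)                  # arange over L*S elements, as in tril()
row = flat // S                          # trunc_div by the column count -> row index
col = flat % S                           # fmod by the column count -> column index
lower_tri = (row >= col).reshape(L, S)   # True on and below the diagonal

attn_bias = np.zeros((L, S), dtype=np.float32)
attn_bias[~lower_tri] = float("-inf")    # -inf above the diagonal: the causal mask
print(attn_bias)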

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 2 additions & 128 deletions
@@ -1,6 +1,6 @@
 import logging
 from enum import Enum, auto
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch._decomp import register_decomposition
@@ -423,135 +423,9 @@ def instance_norm_decomposition(
     )
 
 
-@register_torch_trt_decomposition(
-    aten.scaled_dot_product_attention, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def scaled_dot_product_attention_decomposition(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_mask: Optional[torch.Tensor] = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    *,
-    scale: Optional[float] = None,
-    enable_gqa: bool = False,
-) -> torch.Tensor:
-    L, S = query.size(-2), key.size(-2)
-    device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
-
-    if is_causal:
-        assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
-        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
-
-    if attn_mask is not None:
-        if attn_mask.dtype == torch.bool:
-            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
-        else:
-            attn_bias = attn_mask + attn_bias
-
-    if enable_gqa:
-        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
-        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
-
-    attn_weight = query @ key.transpose(-2, -1)
-
-    if scale is None:
-        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
-        attn_weight = attn_weight / scale
-    else:
-        attn_weight = attn_weight * scale
-
-    attn_weight = attn_weight + attn_bias
-    attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
-
-
-@register_torch_trt_decomposition(
-    aten._scaled_dot_product_flash_attention, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def scaled_dot_product_flash_attention_decomposition(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    *,
-    scale: Optional[float] = None,
-) -> Tuple[
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.SymInt,
-    torch.SymInt,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-]:
-    attn = scaled_dot_product_attention_decomposition(
-        query, key, value, None, dropout_p, is_causal, scale=scale
-    )
-    return attn, None, None, None, 0, 0, None, None, None
-
-
-@register_torch_trt_decomposition(
-    aten._scaled_dot_product_efficient_attention, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def scaled_dot_product_efficient_attention_decomposition(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_bias: Optional[torch.Tensor],
-    compute_log_sumexp: bool,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    *,
-    scale: Optional[float] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    attn = scaled_dot_product_attention_decomposition(
-        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
-    )
-    return attn, None, None, None
-
-
-@register_torch_trt_decomposition(
-    aten._scaled_dot_product_cudnn_attention, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def scaled_dot_product_cudnn_attention_decomposition(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attn_bias: Optional[torch.Tensor],
-    compute_log_sumexp: bool,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    return_debug_mask: bool = False,
-    *,
-    scale: Optional[float] = None,
-) -> Tuple[
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-    torch.SymInt,
-    torch.SymInt,
-    torch.Tensor,
-    torch.Tensor,
-    torch.Tensor,
-]:
-    attn = scaled_dot_product_attention_decomposition(
-        query, key, value, attn_bias, dropout_p, is_causal, scale=scale
-    )
-    return attn, None, None, None, 0, 0, None, None, None
-
-
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-)
+) # type: ignore
 def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
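
The removed decomposition is the reference SDPA formulation, so eager PyTorch remains a convenient oracle for whichever path replaces it. A quick illustrative check (not part of the commit):

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 32) for _ in range(3))
scale = 1.0 / (q.size(-1) ** 0.5)        # default SDPA scaling: 1/sqrt(head_dim)

manual = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(manual, ref, atol=1e-5))  # True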

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_assert_scalar import remove_assert_scalar
 from .remove_detach import remove_detach
@@ -22,6 +23,7 @@
         repair_input_as_output,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        lower_scaled_dot_product_attention,
        view_to_reshape,
        remove_assert_scalar,
        accumulate_fp32_matmul,
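
The re-registered lower_scaled_dot_product_attention pass itself is not shown in this diff. As a rough, hypothetical sketch of what such an FX pass generally does (retarget the aten SDPA variants to torch.nn.functional.scaled_dot_product_attention so the converter above can match them), under the simplifying assumption that only the attention output (tuple element 0) is consumed:

# Hypothetical sketch, NOT the repository's implementation of the pass.
import operator

import torch
from torch.fx import GraphModule

ATEN_SDPA_VARIANTS = (
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
)


def lower_sdpa_sketch(gm: GraphModule) -> GraphModule:
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target not in ATEN_SDPA_VARIANTS:
            continue
        q, k, v = node.args[:3]
        with gm.graph.inserting_before(node):
            sdpa = gm.graph.call_function(
                torch.nn.functional.scaled_dot_product_attention, args=(q, k, v)
            )
        # The aten variants return tuples; assume only getitem(..., 0) users exist
        # and point them (plus any direct users) at the functional call instead.
        for user in list(node.users):
            if user.op == "call_function" and user.target == operator.getitem:
                user.replace_all_uses_with(sdpa)
                gm.graph.erase_node(user)
        node.replace_all_uses_with(sdpa)
        gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm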
