
Commit e6f9aa2

feat: Add support for is_causal argument in attention (#2780)

1 parent: c48db6d

5 files changed: +181 −8 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -2357,8 +2357,14 @@ def aten_ops_max_pool(
     )
 
 
+def attention_validator(node: Node) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
 @dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
 )
 def tensorrt_scaled_dot_product_attention(
     ctx: ConversionContext,
@@ -2375,6 +2381,7 @@ def tensorrt_scaled_dot_product_attention(
         args[0],
         args[1],
         args[2],
+        args_bounds_check(args, 5, False),
         kwargs.get("scale", None),
     )
 
```
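For context, the converter reads `is_causal` from positional index 5 with a `False` default via `args_bounds_check`. A minimal sketch of the bounds-checked lookup this relies on; the `safe_arg` helper below is hypothetical and only assumes the same "safe positional access with a default" behavior, it is not the repository's implementation:

```python
from typing import Any, Sequence


def safe_arg(args: Sequence[Any], index: int, default: Any = None) -> Any:
    """Return args[index] if that position exists, otherwise a default
    (mirrors how args_bounds_check is used in the converter above)."""
    return args[index] if index < len(args) else default


# Positional layout of torch.nn.functional.scaled_dot_product_attention:
# (query, key, value, attn_mask, dropout_p, is_causal, ...)
node_args = ("q", "k", "v", None, 0.0, True)
print(safe_arg(node_args, 5, False))            # True  -> is_causal was passed
print(safe_arg(("q", "k", "v"), 5, False))      # False -> falls back to the default
```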

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -1,11 +1,13 @@
 import math
 from typing import Optional, Union
 
+import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -17,8 +19,11 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    is_causal: bool,
     scale: Optional[float],
 ) -> TRTTensor:
+    L, S = query.shape[-2], key.shape[-2]
+
     mm = impl.matmul.matrix_multiply(
         ctx,
         target,
@@ -46,6 +51,17 @@ def scaled_dot_product_attention(
         mm,
         scale,
     )
+
+    if is_causal:
+        attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled, -1
     )
```

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

Lines changed: 43 additions & 6 deletions
```diff
@@ -3,6 +3,7 @@
 from typing import Callable, Sequence, Tuple
 
 import torch
+from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
@@ -34,6 +35,7 @@ def lower_scaled_dot_product_attention(
 
     if replaced_nodes:
         # Repair instances which use the kwargs field (specifically the "scale" kwarg)
+        # Also repair instances which specified the is_causal or attn_bias fields
         for match in replaced_nodes:
             attention_node_replaced = None
             # Seek the attention operator being replaced
@@ -43,17 +45,52 @@ def lower_scaled_dot_product_attention(
                     break
 
             assert attention_node_replaced is not None
+            assert len(match.replacements) == 1
+
+            new_attention_node = match.replacements[0]
+
+            assert (
+                new_attention_node.target
+                == torch.nn.functional.scaled_dot_product_attention
+            )
 
             # If the attention operator had keyword-args, copy them to the new node
             if attention_node_replaced.kwargs:
-                assert len(match.replacements) == 1
-                new_attention_node = match.replacements[0]
-                assert (
-                    new_attention_node.target
-                    == torch.nn.functional.scaled_dot_product_attention
-                )
                 new_attention_node.kwargs = {**attention_node_replaced.kwargs}
 
+            # Set default args in new node:
+            # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
+            new_attention_node.args = new_attention_node.args + (None, 0.0, False)
+
+            # The `is_causal` argument was specified
+            if (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_flash_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 4, False)
+            ) or (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 6, False)
+            ):
+                new_attention_node.args = (
+                    new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
+                )
+
+            # The `attn_bias` argument was specified
+            if (
+                attention_node_replaced.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ) and args_bounds_check(attention_node_replaced.args, 3) is not None:
+                new_attention_node.args = (
+                    new_attention_node.args[:3]
+                    + attention_node_replaced.args[3]
+                    + new_attention_node.args[4:]
+                )
+
         gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
 
```
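To make the argument bookkeeping above concrete, here is a toy sketch of the tuple splice that flips `is_causal`, assuming the positional layout of `torch.nn.functional.scaled_dot_product_attention` (query, key, value, attn_mask, dropout_p, is_causal), so index 5 holds `is_causal`. The string placeholders stand in for fx graph nodes and are illustration only:

```python
# Placeholder stand-ins for fx node arguments (illustration only)
new_args = ("q", "k", "v", None, 0.0, False)  # defaults appended by the pass

# Flip index 5 (is_causal) to True, keeping every other positional argument
new_args = new_args[:5] + (True,) + new_args[6:]
print(new_args)  # ('q', 'k', 'v', None, 0.0, True)
```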

Lines changed: 112 additions & 0 deletions
```diff
@@ -0,0 +1,112 @@
+import unittest
+
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from ..testing_utilities import DECIMALS_OF_AGREEMENT
+from .harness import DispatchTestCase
+
+
+class TestScaledDotProductAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_no_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, False, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
+
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, True, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
+
+
+@unittest.skipIf(
+    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,
+    "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
+)
+class TestFlashAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    query,
+                    key,
+                    value,
+                    0,
+                    True,  # is_causal
+                    False,
+                    scale=0.25,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            rtol=1e-2,
+            atol=1e-2,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+class TestEfficientAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    query,
+                    key,
+                    value,
+                    None,
+                    False,
+                    0,
+                    True,  # is_causal
+                    scale=0.5,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            rtol=1e-2,
+            atol=1e-2,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
```
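These tests exercise the causal path end to end. For reference, in eager PyTorch `is_causal=True` with equal query/key lengths should match passing an explicit lower-triangular boolean mask, which is the behavior the converter's generated bias mirrors. A quick sketch under that assumption, with arbitrary small shapes:

```python
import torch

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

causal = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit causal mask: True marks key positions each query may attend to
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
masked = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(causal, masked, atol=1e-5))  # expected: True
```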

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,9 +1,10 @@
 import unittest
 
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
 
```
