Commit 8ff834e

feat: Add support for is_causal in attention

1 parent 67675d7 commit 8ff834e

5 files changed: +168 -2 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 0 deletions
@@ -2357,8 +2357,14 @@ def aten_ops_max_pool(
     )


+def attention_validator(node: Node) -> bool:
+    # Currently, `attn_mask` is not supported
+    return args_bounds_check(node.args, 3) is None
+
+
 @dynamo_tensorrt_converter(
     torch.nn.functional.scaled_dot_product_attention,
+    capability_validator=attention_validator,
 )
 def tensorrt_scaled_dot_product_attention(
     ctx: ConversionContext,
@@ -2375,6 +2381,7 @@ def tensorrt_scaled_dot_product_attention(
         args[0],
         args[1],
         args[2],
+        args_bounds_check(args, 5, False),
         kwargs.get("scale", None),
     )
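For context, both the validator and the converter read optional positional arguments through args_bounds_check. The sketch below assumes the helper simply returns args[i] when the index exists and a caller-supplied default otherwise (the exact in-repo definition may differ); sdpa_args is a made-up placeholder tuple, not a real graph node's arguments.

from typing import Any, Optional, Sequence


def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Return args[i] if it exists, otherwise the supplied default.
    return args[i] if len(args) > i else replacement


# Mirrors the validator above: conversion is allowed only when no attn_mask
# (positional index 3) was passed to scaled_dot_product_attention.
sdpa_args = ("query", "key", "value")           # no attn_mask supplied
print(args_bounds_check(sdpa_args, 3) is None)  # True -> converter accepted
print(args_bounds_check(sdpa_args, 5, False))   # False -> is_causal defaults off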

py/torch_tensorrt/dynamo/conversion/impl/attention.py

Lines changed: 17 additions & 1 deletion
@@ -1,11 +1,13 @@
 import math
 from typing import Optional, Union

+import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.types import TRTTensor


@@ -17,8 +19,11 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    is_causal: bool,
     scale: Optional[float],
 ) -> TRTTensor:
+    L, S = query.shape[-2], key.shape[-2]
+
     mm = impl.matmul.matrix_multiply(
         ctx,
         target,
@@ -46,6 +51,17 @@ def scaled_dot_product_attention(
         mm,
         scale,
     )
+
+    if is_causal:
+        attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+
+        scaled = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled, -1
     )

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

Lines changed: 35 additions & 0 deletions
@@ -3,6 +3,7 @@
 from typing import Callable, Sequence, Tuple

 import torch
+from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
@@ -34,6 +35,7 @@ def lower_scaled_dot_product_attention(

     if replaced_nodes:
         # Repair instances which use the kwargs field (specifically the "scale" kwarg)
+        # Also repair instances which specified the is_causal or attn_bias fields
         for match in replaced_nodes:
             attention_node_replaced = None
             # Seek the attention operator being replaced
@@ -54,6 +56,39 @@ def lower_scaled_dot_product_attention(
             )
             new_attention_node.kwargs = {**attention_node_replaced.kwargs}

+            # Set default args in new node:
+            # Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
+            new_attention_node.args = new_attention_node.args + (None, 0.0, False)
+
+            # The `is_causal` argument was specified
+            if (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_flash_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 4, False)
+            ) or (
+                (
+                    attention_node_replaced.target
+                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
+                )
+                and args_bounds_check(attention_node_replaced.args, 6, False)
+            ):
+                new_attention_node.args = (
+                    new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
+                )
+
+            # The `attn_bias` argument was specified
+            if (
+                attention_node_replaced.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ) and args_bounds_check(attention_node_replaced.args, 3) is not None:
+                new_attention_node.args = (
+                    new_attention_node.args[:3]
+                    + (attention_node_replaced.args[3],)
+                    + new_attention_node.args[4:]
+                )
+
     gm = clean_up_graph_after_modifications(gm)
     logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")
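The pass normalizes every matched node to the full positional signature of torch.nn.functional.scaled_dot_product_attention, i.e. (query, key, value, attn_mask, dropout_p, is_causal), and then copies is_causal (and attn_bias, for the efficient variant) out of the original aten op's argument list. Below is a plain-tuple sketch of that index arithmetic; replaced_args and new_args are illustrative placeholders standing in for fx graph node arguments.

# Positional layouts checked by the pass:
#   _scaled_dot_product_flash_attention(query, key, value, dropout_p, is_causal, ...)
#   _scaled_dot_product_efficient_attention(query, key, value, attn_bias,
#                                           compute_log_sumexp, dropout_p, is_causal, ...)
replaced_args = ("q", "k", "v", 0.0, True)  # flash attention call with is_causal=True

# Step 1: pad the replacement node to SDPA's defaults:
#   (query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
new_args = ("q", "k", "v") + (None, 0.0, False)

# Step 2: the original node requested causal masking (index 4 for flash
# attention), so flip is_causal (index 5 of the SDPA signature) to True.
if len(replaced_args) > 4 and replaced_args[4]:
    new_args = new_args[:5] + (True,) + new_args[6:]

print(new_args)  # ('q', 'k', 'v', None, 0.0, True)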

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import unittest
+
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestScaledDotProductAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_no_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, False, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, precision=torch.float16)
+
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                return torch.nn.functional.scaled_dot_product_attention(
+                    query, key, value, None, 0.0, True, scale=None
+                )
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(SDPA(), inputs, precision=torch.float16)
+
+
+@unittest.skipIf(
+    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,
+    "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
+)
+class TestFlashAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                    query,
+                    key,
+                    value,
+                    0,
+                    True,  # is_causal
+                    False,
+                    scale=0.25,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+class TestEfficientAttention(DispatchTestCase):
+    @parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
+    def test_sdpa_causal(self, query_shape, key_shape):
+        class SDPA(nn.Module):
+            def forward(self, query, key, value):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    query,
+                    key,
+                    value,
+                    None,
+                    False,
+                    0,
+                    True,  # is_causal
+                    scale=0.5,
+                )
+                return attn[0]
+
+        inputs = []
+        query = torch.randn(query_shape, dtype=torch.float16)
+        key = torch.rand(key_shape, dtype=torch.float16)
+        value = torch.rand(key_shape, dtype=torch.float16)
+        inputs.extend([query, key, value])
+        self.run_test(
+            SDPA(),
+            inputs,
+            precision=torch.float16,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
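These tests compare the TensorRT converter output against PyTorch eager execution. As a reminder of what the causal cases should compute, here is a small eager-mode check, not part of the commit, showing that is_causal=True matches an explicit lower-triangular boolean attn_mask; the shapes are arbitrary placeholders:

import torch

# Arbitrary small shapes for a quick equivalence check.
query = torch.randn(2, 4, 8, 16)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)

causal = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
)

# An explicit lower-triangular boolean mask should produce the same result.
L, S = query.shape[-2], key.shape[-2]
mask = torch.tril(torch.ones(L, S, dtype=torch.bool))
masked = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False
)

print(torch.allclose(causal, masked, atol=1e-6))  # True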

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 import unittest

 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests

+import torch_tensorrt
+
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing

