Skip to content

Commit c5cb551

Browse files
committed
Minor fix
1 parent 8ff834e commit c5cb551

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,17 @@ def lower_scaled_dot_product_attention(
4545
break
4646

4747
assert attention_node_replaced is not None
48+
assert len(match.replacements) == 1
49+
50+
new_attention_node = match.replacements[0]
51+
52+
assert (
53+
new_attention_node.target
54+
== torch.nn.functional.scaled_dot_product_attention
55+
)
4856

4957
# If the attention operator had keyword-args, copy them to the new node
5058
if attention_node_replaced.kwargs:
51-
assert len(match.replacements) == 1
52-
new_attention_node = match.replacements[0]
53-
assert (
54-
new_attention_node.target
55-
== torch.nn.functional.scaled_dot_product_attention
56-
)
5759
new_attention_node.kwargs = {**attention_node_replaced.kwargs}
5860

5961
# Set default args in new node:

tests/py/dynamo/conversion/test_attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from parameterized import parameterized
66
from torch.testing._internal.common_utils import run_tests
77

8+
from ..testing_utilities import DECIMALS_OF_AGREEMENT
89
from .harness import DispatchTestCase
910

1011

0 commit comments

Comments
 (0)