File tree Expand file tree Collapse file tree 2 files changed +9
-6
lines changed
py/torch_tensorrt/dynamo/lowering/passes
tests/py/dynamo/conversion Expand file tree Collapse file tree 2 files changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -45,15 +45,17 @@ def lower_scaled_dot_product_attention(
45
45
break
46
46
47
47
assert attention_node_replaced is not None
48
+ assert len (match .replacements ) == 1
49
+
50
+ new_attention_node = match .replacements [0 ]
51
+
52
+ assert (
53
+ new_attention_node .target
54
+ == torch .nn .functional .scaled_dot_product_attention
55
+ )
48
56
49
57
# If the attention operator had keyword-args, copy them to the new node
50
58
if attention_node_replaced .kwargs :
51
- assert len (match .replacements ) == 1
52
- new_attention_node = match .replacements [0 ]
53
- assert (
54
- new_attention_node .target
55
- == torch .nn .functional .scaled_dot_product_attention
56
- )
57
59
new_attention_node .kwargs = {** attention_node_replaced .kwargs }
58
60
59
61
# Set default args in new node:
Original file line number Diff line number Diff line change 5
5
from parameterized import parameterized
6
6
from torch .testing ._internal .common_utils import run_tests
7
7
8
+ from ..testing_utilities import DECIMALS_OF_AGREEMENT
8
9
from .harness import DispatchTestCase
9
10
10
11
You can’t perform that action at this time.
0 commit comments