Skip to content

Commit 7393a85

Browse files
committed
Remove the fuse_distributed_ops lowering pass on Tegra platforms
1 parent 73a7aac commit 7393a85

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,22 @@
1717
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1818
from .view_to_reshape import view_to_reshape
1919

20-
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
21-
[
22-
remove_input_alias_fixing_clones,
23-
constant_fold,
24-
repair_input_as_output,
25-
fuse_prims_broadcast,
26-
fuse_distributed_ops,
27-
replace_max_pool_with_indices,
28-
lower_scaled_dot_product_attention,
29-
view_to_reshape,
30-
remove_assert_scalar,
31-
accumulate_fp32_matmul,
32-
]
33-
)
20+
pass_list = [
21+
remove_input_alias_fixing_clones,
22+
constant_fold,
23+
repair_input_as_output,
24+
fuse_prims_broadcast,
25+
replace_max_pool_with_indices,
26+
lower_scaled_dot_product_attention,
27+
view_to_reshape,
28+
remove_assert_nodes,
29+
accumulate_fp32_matmul,
30+
]
31+
32+
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
33+
pass_list.append(fuse_distributed_ops)
34+
35+
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)
3436

3537
ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
3638
[

0 commit comments

Comments
 (0)