Skip to content

Commit 2b1e840

Browse files
committed
utility function to detect tegra platform
1 parent 7393a85 commit 2b1e840

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.utils import is_tegra_platform
67

78
from .accumulate_fp32_matmul import accumulate_fp32_matmul
89
from .constant_folding import constant_fold
@@ -25,11 +26,11 @@
2526
replace_max_pool_with_indices,
2627
lower_scaled_dot_product_attention,
2728
view_to_reshape,
28-
remove_assert_nodes,
29+
remove_assert_scalar,
2930
accumulate_fp32_matmul,
3031
]
3132

32-
if torch.cuda.get_device_capability() not in [(8, 7), (7, 2)]:
33+
if not is_tegra_platform():
3334
pass_list.append(fuse_distributed_ops)
3435

3536
ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,9 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
793793
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
794794
)
795795
return output_dtypes
796+
797+
798+
def is_tegra_platform() -> bool:
799+
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
800+
return True
801+
return False

0 commit comments

Comments
 (0)