
Commit 1840e36

fix: Remove pytorch overhead while finding fusions for fully convertible models (#3311)
1 parent: 3dded9d

File tree

2 files changed: 7 additions & 0 deletions

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 0 deletions
@@ -765,7 +765,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,
+            skip_fusion=(num_supported_ops == total_ops),
         )
+
     except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
         logger.error(
             "Partitioning failed on the subgraph with fast partition. See trace above. "

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 5 additions & 0 deletions
@@ -111,6 +111,7 @@ def __init__(
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
+        skip_fusion: bool = False,
     ):
         """
         Preprocesses graph before splitting:

@@ -127,6 +128,7 @@ def __init__(
         self.settings = _SplitterSettingBase(
             min_acc_module_size=min_block_size,
             allow_non_tensor=True,
+            skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support

@@ -252,6 +254,7 @@ def partition(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    skip_fusion: bool = False,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support

@@ -262,6 +265,7 @@
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Require that all computational operators be run in TRT
+        skip_fusion: Skip fusions found by FxNetAccFusionsFinder
     Returns:
         torch.fx.GraphModule, OpSupportTester
     """

@@ -277,6 +281,7 @@
         supported_ops,
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
+        skip_fusion=skip_fusion,
     )

     partitioned_graph = partitioner.partition_graph()
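
For reference, a hedged usage sketch of the new keyword. The import path follows this commit's file layout, but the sample module, the symbolic_trace step, and the positional argument are illustrative assumptions; real call sites inside torch_tensorrt operate on dynamo-lowered aten graphs:

import torch
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition

class TwoLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Linear(8, 8)
        self.b = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.b(torch.relu(self.a(x)))

# Illustrative tracing step; the library normally receives an already
# lowered GraphModule from the dynamo frontend.
gm = torch.fx.symbolic_trace(TwoLinear())

# skip_fusion=True is only safe when the graph is known to be fully
# convertible: with a single accelerated block there are no fusion groups
# that a partition boundary could split apart.
partitioned_gm, op_support = partition(gm, skip_fusion=True)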
