@@ -111,6 +111,7 @@ def __init__(
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
+        skip_fusion: bool = False,
     ):
         """
         Preprocesses graph before splitting:
@@ -127,6 +128,7 @@ def __init__(
         self.settings = _SplitterSettingBase(
             min_acc_module_size=min_block_size,
             allow_non_tensor=True,
+            skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support
 
@@ -252,6 +254,7 @@ def partition(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    skip_fusion: bool = False,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -262,6 +265,7 @@ def partition(
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
        require_full_compilation: Require that all computational operators be run in TRT
+        skip_fusion: Skip fusions found by FxNetAccFusionsFinder
     Returns:
         torch.fx.GraphModule, OpSupportTester
     """
@@ -277,6 +281,7 @@ def partition(
         supported_ops,
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
+        skip_fusion=skip_fusion,
     )
 
     partitioned_graph = partitioner.partition_graph()
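A rough usage sketch of the new flag, not taken from the diff: it shows how a caller might pass skip_fusion=True through the partition() entrypoint changed above. The import path, the GraphModule argument name (gm), and the specific torch_executed_ops entry are assumptions for illustration only.

import torch
from torch_tensorrt.dynamo import partitioning  # assumed module path

def lower_with_skip_fusion(gm: torch.fx.GraphModule):
    # Partition the aten-level graph into TRT/Torch blocks; skip_fusion=True
    # disables the FxNetAccFusionsFinder grouping pass before splitting.
    partitioned_module, op_support = partitioning.partition(
        gm,
        min_block_size=5,
        torch_executed_ops={torch.ops.aten.embedding.default},  # example op kept in Torch
        require_full_compilation=False,
        skip_fusion=True,  # new flag introduced by this change
    )
    return partitioned_module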