
Commit 8efbee9

feat: Add support for require_full_compilation in Dynamo (#2138)
1 parent b9c8578 commit 8efbee9

File tree

8 files changed: +150 −15 lines

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
+REQUIRE_FULL_COMPILATION = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -57,3 +58,4 @@ class CompilationSettings:
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION
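Because CompilationSettings is a dataclass whose defaults come from _defaults.py, the new field can be set explicitly or left to fall back to REQUIRE_FULL_COMPILATION. A minimal sketch (the private import path simply mirrors the file shown above and may differ in packaged releases):

# Sketch only: constructing settings with the new flag.
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(require_full_compilation=True)
print(settings.require_full_compilation)                 # True
print(CompilationSettings().require_full_compilation)    # False (REQUIRE_FULL_COMPILATION)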

py/torch_tensorrt/dynamo/compile.py

Lines changed: 7 additions & 3 deletions

@@ -20,6 +20,7 @@
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
+    REQUIRE_FULL_COMPILATION,
     TRUNCATE_LONG_AND_DOUBLE,
     USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
@@ -57,7 +58,7 @@ def compile(
     dla_global_dram_size: int = 536870912,
     calibrator: object = None,
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation: bool = False,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[List[str]] = None,
     torch_executed_modules: Optional[List[str]] = None,
@@ -80,8 +81,10 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         "following arguments are supported: "
         "{enabled_precisions, debug, workspace_size, min_block_size, "
-        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
-        "enable_experimental_decompositions}"
+        "max_aux_streams, version_compatible, optimization_level, "
+        "torch_executed_ops, pass_through_build_failures, "
+        "use_fast_partitioner, enable_experimental_decompositions, "
+        "require_full_compilation}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -126,6 +129,7 @@ def compile(
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
+        "require_full_compilation": require_full_compilation,
     }

     settings = CompilationSettings(**compilation_options)
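With the option now routed through compilation_options, a caller can request full compilation from the Dynamo compile entry point. The sketch below is illustrative only: the toy model and input shapes are placeholders, and the exact signature of torch_tensorrt.dynamo.compile may vary between versions.

import torch
import torch_tensorrt

# Placeholder model and inputs, used purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 32).cuda()]

# With require_full_compilation=True, partitioning raises an AssertionError if any
# operator in the lowered graph lacks a TensorRT converter, instead of silently
# falling back to Torch execution for that segment.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    require_full_compilation=True,
    min_block_size=1,
)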

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 47 additions & 8 deletions

@@ -12,7 +12,11 @@
     _SplitterSettingBase,
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -92,6 +96,7 @@ class TRTPartitioner(_SplitterBase):  # type: ignore
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -104,6 +109,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ):
         """
         Preprocesses graph before splitting:
@@ -142,6 +148,7 @@ def __init__(

         self.num_trt_accelerated_subgraphs: Optional[int] = None
         self.allowed_single_node_partition_ops = allowed_single_node_partition_ops
+        self.require_full_compilation = require_full_compilation

     def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
         """
@@ -151,12 +158,16 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph
         result: List[Subgraph] = []
         for subgraph in subgraphs:
             if subgraph.is_acc:
-                if len(subgraph.nodes) >= self.settings.min_acc_module_size or (
-                    self.allowed_single_node_partition_ops is not None
-                    and any(
-                        ConverterRegistry.qualified_name_or_str(node.target)
-                        in self.allowed_single_node_partition_ops
-                        for node in subgraph.nodes
+                if (
+                    len(subgraph.nodes) >= self.settings.min_acc_module_size
+                    or self.require_full_compilation
+                    or (
+                        self.allowed_single_node_partition_ops is not None
+                        and any(
+                            ConverterRegistry.qualified_name_or_str(node.target)
+                            in self.allowed_single_node_partition_ops
+                            for node in subgraph.nodes
+                        )
                     )
                 ):
                     result.append(subgraph)
@@ -185,6 +196,27 @@ def partition_graph(self) -> torch.fx.GraphModule:
         # Delegate nodes based on operator coverage
         subgraphs = self.put_nodes_into_subgraphs()

+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len([s for s in subgraphs if s.is_acc]) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.settings.min_acc_module_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # Remove segments smaller than the block size (with exceptions)
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)

@@ -217,6 +249,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -226,6 +259,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -236,7 +270,12 @@ def partition(

     # Construct
     supported_ops = OpSupportTester(torch_executed_ops=torch_executed_ops)
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     partitioned_graph = partitioner.partition_graph()
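For the fast (adjacency-based) partitioner, require_full_compilation has two effects: a fully supported single-partition graph is retained even when it is smaller than min_block_size, and a graph with unsupported operators fails loudly. The following sketch illustrates the failure mode; treating aten.nonzero.default as an operator without a converter is an assumption for illustration and depends on current converter coverage.

import torch
from torch_tensorrt.dynamo import partitioning

class MixedSupport(torch.nn.Module):
    def forward(self, x):
        y = torch.ops.aten.add.Tensor(x, x)       # has a TRT converter
        return torch.ops.aten.nonzero.default(y)  # assumed here to lack a converter

gm = torch.fx.symbolic_trace(MixedSupport())

try:
    partitioning.fast_partition(gm, require_full_compilation=True)
except AssertionError as err:
    # "require_full_compilation=True was specified, but model is not fully supported"
    print(err)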

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 42 additions & 3 deletions

@@ -5,7 +5,11 @@
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
-from torch_tensorrt.dynamo._defaults import DEBUG, MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MIN_BLOCK_SIZE,
+    REQUIRE_FULL_COMPILATION,
+)
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -26,6 +30,7 @@ class TRTPartitioner(CapabilityBasedPartitioner):  # type: ignore[misc]
         allowed_single_node_partition_ops: Nodes which can be included in single-node partitons.
             Generally useful for module-level exclusion ops which are intensive despite being single functions
         min_block_size: Minimum number of computational operators per block
+        require_full_compilation: Require that all computational operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -40,6 +45,7 @@ def __init__(
             Collection[str]
         ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size: int = MIN_BLOCK_SIZE,
+        require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     ) -> None:
         super().__init__(
             graph_module,
@@ -50,12 +56,34 @@ def __init__(
         )

         self.min_block_size = min_block_size
+        self.require_full_compilation = require_full_compilation

     def propose_partitions(self) -> List[Partition]:
         # Propose partitions using the default, then refine the results
         initial_proposed_partitions = super().propose_partitions()
         partitions = dict(enumerate(initial_proposed_partitions))

+        # A graph is fully supported if there is a single partition and all operators are supported/convertible
+        full_support = len(partitions) == 1 and not getattr(
+            self.operator_support, "unsupported_operators", True
+        )
+
+        if not full_support and self.require_full_compilation:
+            raise AssertionError(
+                "require_full_compilation=True was specified, but model is not fully supported"
+            )
+
+        if (
+            full_support
+            and self.require_full_compilation
+            and self.min_block_size != MIN_BLOCK_SIZE
+        ):
+            logger.warning(
+                "Detected both require_full_compilation and min_block_size compilation "
+                "arguments were specified. Disregarding min_block_size argument for "
+                "fully supported model."
+            )
+
         # For each partition, determine whether or not the number of computational operators
         # exceeds the threshold, and if not, remove that partition
         partitions_to_remove = {}
@@ -81,7 +109,11 @@ def propose_partitions(self) -> List[Partition]:
             ):
                 compute_node_count += 1

-            if compute_node_count < self.min_block_size and not exempted_partition:
+            if (
+                compute_node_count < self.min_block_size
+                and not exempted_partition
+                and not (full_support and self.require_full_compilation)
+            ):
                 partitions_to_remove[id] = compute_node_count

         # Remove any nodes violating the criteria specified by the user
@@ -172,6 +204,7 @@ def partition(
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> torch.fx.GraphModule:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -181,6 +214,7 @@ def partition(
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule
     """
@@ -189,7 +223,12 @@ def partition(
         if torch_executed_ops is not None
         else set()
     )
-    partitioner = TRTPartitioner(gm, supported_ops, min_block_size=min_block_size)
+    partitioner = TRTPartitioner(
+        gm,
+        supported_ops,
+        min_block_size=min_block_size,
+        require_full_compilation=require_full_compilation,
+    )

     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
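The global partitioner applies the same rule from the other direction: when the graph is fully supported and full compilation is requested, the single proposed partition is kept even if it falls below min_block_size, with a warning logged if a non-default min_block_size was also supplied. A hedged sketch, reusing the single-op module style from the tests below:

import torch
from torch_tensorrt.dynamo import partitioning

class TinyFullySupported(torch.nn.Module):
    def forward(self, x, y):
        return torch.ops.aten.add.Tensor(x, y)

gm = torch.fx.symbolic_trace(TinyFullySupported())

# A 1-op graph would normally be rejected by min_block_size=7; with
# require_full_compilation=True it is retained as one accelerated partition,
# and the warning about disregarding min_block_size is emitted.
partitioned = partitioning.global_partition(
    gm,
    min_block_size=7,
    require_full_compilation=True,
)
print(len(list(partitioned.named_children())))  # expected: 1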

py/torch_tensorrt/dynamo/utils.py

Lines changed: 9 additions & 1 deletion

@@ -216,7 +216,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
             "If this is incorrect, please specify an input device, via the device keyword."
         )

-    logger.info(f"Compiling with Settings:\n{settings}")
+    # Ignore and warn about require_full_compilation flag
+    if settings.require_full_compilation:
+        logger.warning(
+            "Detected require_full_compilation=True for a torch.compile run. "
+            "This option has no effect in torch.compile."
+        )
+        settings.require_full_compilation = False
+
+    logger.info("Compilation Settings: %s\n", settings)

     return settings
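In the torch.compile workflow the flag is therefore accepted but neutralized. A sketch of that path, assuming the backend is registered under the name "torch_tensorrt" and that backend options are forwarded to parse_dynamo_kwargs (both assumptions about the surrounding integration, not part of this diff):

import torch
import torch_tensorrt  # noqa: F401  (assumed to register the TensorRT backend)

model = torch.nn.Linear(16, 16).eval().cuda()

# The option is parsed into CompilationSettings, a warning is logged, and the
# flag is reset to False, so unsupported segments still fall back to Torch.
opt_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"require_full_compilation": True},
)
opt_model(torch.randn(4, 16).cuda())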

tests/py/dynamo/partitioning/test_fast_partitioning.py

Lines changed: 24 additions & 0 deletions

@@ -31,6 +31,30 @@ def forward(self, x, y):
             "Single operators should not be segmented",
         )

+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.fast_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(
+                [
+                    1
+                    for submod in list(partitioned_graph.named_children())
+                    if "_run_on_acc" in submod[0]
+                ]
+            ),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:

tests/py/dynamo/partitioning/test_global_partitioning.py

Lines changed: 18 additions & 0 deletions

@@ -25,6 +25,24 @@ def forward(self, x, y):
             "Single operators should not be segmented",
         )

+    def test_partition_fully_supported_one_op_require_full_compilation(self):
+        class FullySupportedOneOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                return torch.ops.aten.add.Tensor(x, y)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
+        partitioned_graph = partitioning.global_partition(
+            deepcopy(fx_graph), require_full_compilation=True
+        )
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "Single operators can be segmented if full compilation is required",
+        )
+
     def test_partition_fully_supported_multi_op(self):
         class FullySupportedMultiOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:

0 commit comments
