Commit a6537cf

fix: Repair usage of torch_executed_ops
- Previously, `torch_executed_ops` were excluded at partitioning time, but not at conversion time, causing a bug with obscure usages of `getitem`
- Now, `torch_executed_ops` are excluded at partitioning time and their converters are explicitly disabled
1 parent 128dd65 commit a6537cf
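
For reference, a minimal usage sketch of the option this commit repairs (the module, input shapes, and the top-level `torch_tensorrt.compile` entry point with `ir="dynamo"` are illustrative assumptions, not part of the commit): ops listed in `torch_executed_ops` now fall back to PyTorch at both partitioning and conversion time, and the collection may hold aten `Target` objects or their string names.

import torch
import torch_tensorrt

# Hypothetical module and inputs, purely for illustration
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(4, 16).cuda()]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    # Accepts aten Target objects or their string names; both forms are
    # checked by the partitioners after this change
    torch_executed_ops={torch.ops.aten.relu.default},
    min_block_size=1,
)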

7 files changed, +61 -27 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 11 additions & 4 deletions
@@ -2,10 +2,11 @@
 
 import collections.abc
 import logging
-from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 from torch.export import ExportedProgram
+from torch.fx.node import Target
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ -49,6 +50,9 @@
     convert_module,
     repair_long_or_double_inputs,
 )
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import (
     get_torch_inputs,
@@ -85,7 +89,7 @@ def compile(
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[List[str]] = None,
+    torch_executed_ops: Optional[Collection[Target]] = None,
     torch_executed_modules: Optional[List[str]] = None,
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
@@ -143,7 +147,7 @@ def compile(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
-        torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        torch_executed_ops (Collection[Target]): Set of aten operators that must be run in PyTorch. An error will be thrown if this set is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         pass_through_build_failures (bool): Error out if there are issues during compilation (only applicable to torch.compile workflows)
         max_aux_stream (Optional[int]): Maximum streams in the engine
@@ -212,7 +216,7 @@ def compile(
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops
         if torch_executed_ops is not None
-        else [],
+        else set(),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
@@ -256,6 +260,9 @@ def compile_module(
     """
     dryrun_tracker = DryRunTracker()
 
+    # Set torch-executed ops
+    CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
+
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 3 deletions
@@ -1,8 +1,9 @@
 from dataclasses import dataclass, field
-from typing import Optional, Set, Union
+from typing import Collection, Optional, Union
 
 import torch
 from tensorrt import EngineCapability
+from torch.fx.node import Target
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
@@ -41,7 +42,7 @@ class CompilationSettings:
         debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
-        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops (Collection[Target]): Collection of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
         max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
         version_compatible (bool): Provide version forward-compatibility for engine plan files
@@ -75,7 +76,7 @@ class CompilationSettings:
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
-    torch_executed_ops: Set[str] = field(default_factory=set)
+    torch_executed_ops: Collection[Target] = field(default_factory=set)
    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
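
On the settings side, `torch_executed_ops` is now typed as `Collection[Target]` and defaults to an empty set. A short sketch (assuming `CompilationSettings` is importable from the module path shown above and that all other fields can be left at their defaults):

import torch

from torch_tensorrt.dynamo._settings import CompilationSettings

# Targets and qualified string names can be mixed in the collection
settings = CompilationSettings(
    torch_executed_ops={torch.ops.aten.abs.default, "torch.ops.aten.relu.default"}
)
print(settings.torch_executed_ops)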

py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py

Lines changed: 25 additions & 0 deletions
@@ -6,6 +6,7 @@
 from typing import (
     Any,
     Callable,
+    Collection,
     Dict,
     List,
     Optional,
@@ -212,8 +213,16 @@ def __init__(
             CallingConvention.CTX for _ in range(len(self.registries))
         ]
 
+        self.disallowed_targets: Collection[Target] = set()
+
         self.validate_invariants()
 
+    def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
+        self.disallowed_targets = torch_executed_ops
+
+    def get_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
+        self.disallowed_targets = torch_executed_ops
+
     def validate_invariants(self) -> None:
         """Validates the invariants required of the dictionaries in the registries
@@ -253,6 +262,14 @@ def __getitem_without_validation__(
 
         self.validate_invariants()
 
+        if (
+            key in self.disallowed_targets
+            or self.qualified_name_or_str(key) in self.disallowed_targets
+        ):
+            raise KeyError(
+                f"A converter exists for {key}, but it was " "explicitly disallowed"
+            )
+
         # Iterate over all registries and return the first converter found
         for registry, calling_convention in zip(
             self.registries, self.registry_calling_conventions
@@ -288,6 +305,14 @@ def __getitem__(
         self.validate_invariants()
         key = node.target
 
+        if (
+            key in self.disallowed_targets
+            or self.qualified_name_or_str(key) in self.disallowed_targets
+        ):
+            raise KeyError(
+                f"A converter exists for {key}, but it was " "explicitly disallowed"
+            )
+
         # Iterate over all registries, validating the converter on the input node
         # If no capability_validator function is found, assume full coverage
         for registry, calling_convention in zip(
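
To make the two new guards concrete, here is a simplified, self-contained sketch of the lookup pattern (a toy registry, not the real `ConverterRegistry`; `qualified_name_or_str` below is a stand-in for the registry helper of the same name): a converter lookup now fails for any target marked torch-executed, whether it was supplied as a callable or as its string name.

from typing import Any, Callable, Collection, Dict, Union

# Toy stand-in types for illustration only
Target = Union[Callable[..., Any], str]


def qualified_name_or_str(target: Target) -> str:
    # Stand-in for ConverterRegistry.qualified_name_or_str
    if isinstance(target, str):
        return target
    return f"{target.__module__}.{target.__qualname__}"


class ToyConverterRegistry:
    def __init__(self, converters: Dict[Target, Callable[..., Any]]) -> None:
        self.converters = converters
        self.disallowed_targets: Collection[Target] = set()

    def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
        self.disallowed_targets = torch_executed_ops

    def __getitem__(self, key: Target) -> Callable[..., Any]:
        # Mirrors the guard added in this commit: reject by object or by qualified name
        if (
            key in self.disallowed_targets
            or qualified_name_or_str(key) in self.disallowed_targets
        ):
            raise KeyError(f"A converter exists for {key}, but it was explicitly disallowed")
        return self.converters[key]


registry = ToyConverterRegistry({abs: lambda *args: "converted"})
registry.set_disallowed_targets({abs})
try:
    registry[abs]
except KeyError as err:
    print(err)  # the converter exists, but the target was explicitly disallowed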

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 4 additions & 2 deletions
@@ -42,8 +42,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or node.op == "get_attr"
-        ) and node_name not in self.torch_executed_ops:
+            (node in CONVERTERS or node.op == "get_attr")
+            and node_name not in self.torch_executed_ops
+            and node.target not in self.torch_executed_ops
+        ):
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
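
The added `node.target not in self.torch_executed_ops` clause matters because the collection can now hold either form of an op. A short illustration (the string spelling below is the one the updated tests in this commit use):

import torch

# torch_executed_ops may mix OpOverload Targets and qualified string names
torch_executed_ops = {torch.ops.aten.abs.default, "torch.ops.aten.relu.default"}

# For a node calling aten.abs, the new Target check matches the OpOverload itself...
assert torch.ops.aten.abs.default in torch_executed_ops
# ...while an op supplied as a string is still caught by the existing name check
assert "torch.ops.aten.relu.default" in torch_executed_ops

Either membership test excluding the node keeps it in PyTorch; the same pair of checks is applied in the global partitioner below.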

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 11 additions & 14 deletions
@@ -1,8 +1,9 @@
 import logging
-from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple
+from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple
 
 import torch
 from torch.fx.graph_module import GraphModule
+from torch.fx.node import Target
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
 from torch_tensorrt.dynamo._defaults import (
@@ -133,25 +134,25 @@ class TorchTensorRTOperatorSupport(OperatorSupport):  # type: ignore[misc]
     def __init__(
         self,
         support_dict: Optional[SupportDict] = None,
-        torch_executed_ops: Optional[Set[str]] = None,
+        torch_executed_ops: Collection[Target] = set(),
     ):
         super().__init__(support_dict)
 
         # Initialize sets of supported/unsupported operators
         self.supported_operators: Dict[str, int] = {}
         self.unsupported_operators: Dict[str, int] = {}
-        self.torch_executed_ops: Set[str] = (
-            torch_executed_ops if torch_executed_ops is not None else set()
-        )
+        self.torch_executed_ops: Collection[Target] = torch_executed_ops
 
     def is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or node.op == "get_attr"
-        ) and node_name not in self.torch_executed_ops:
+            (node in CONVERTERS or node.op == "get_attr")
+            and node_name not in self.torch_executed_ops
+            and node.target not in self.torch_executed_ops
+        ):
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
@@ -201,7 +202,7 @@ def partition(
     gm: torch.fx.GraphModule,
     verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Optional[Set[str]] = None,
+    torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
 ) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]:
     """Partition an FX GraphModule with aten ops into TRT engines
@@ -211,16 +212,12 @@ def partition(
         gm: FX GraphModule to partition
         verbose: Bool representing whether to print operator support
         min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Whether to require that all operators be run in TRT
     Returns:
         torch.fx.GraphModule, TorchTensorRTOperatorSupport
     """
-    supported_ops = TorchTensorRTOperatorSupport(
-        torch_executed_ops=torch_executed_ops
-        if torch_executed_ops is not None
-        else set()
-    )
+    supported_ops = TorchTensorRTOperatorSupport(torch_executed_ops=torch_executed_ops)
     partitioner = TRTPartitioner(
         gm,
         supported_ops,

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 3 additions & 2 deletions
@@ -3,9 +3,10 @@
 import pytest
 import timm
 import torch
-import torch_tensorrt as torchtrt
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
+import torch_tensorrt as torchtrt
+
 assertions = unittest.TestCase()
 
 
@@ -97,7 +98,7 @@ def forward(self, x):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "torch_executed_ops": "torch.ops.aten.abs.default",
+        "torch_executed_ops": {"torch.ops.aten.abs.default"},
         "min_block_size": 1,
     }

tests/py/dynamo/models/test_export_serde.py

Lines changed: 3 additions & 2 deletions
@@ -3,10 +3,11 @@
 import pytest
 import timm
 import torch
-import torch_tensorrt as torchtrt
 import torchvision.models as models
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
+import torch_tensorrt as torchtrt
+
 assertions = unittest.TestCase()
 
 
@@ -206,7 +207,7 @@ def forward(self, x):
         ],
         "ir": ir,
         "min_block_size": 1,
-        "torch_executed_ops": "torch.ops.aten.relu.default",
+        "torch_executed_ops": {"torch.ops.aten.relu.default"},
     }
 
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
