Skip to content

Commit 8f39c94

Browse files
committed
fix: Raise error when registering Packet-keyed converter
- Packet-keyed converters are too error-prone, since new overloads could be added or existing overloads may have been overlooked
1 parent 8c92918 commit 8f39c94

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
cast,
1717
)
1818

19+
from torch._ops import OpOverloadPacket
1920
from torch.fx.node import Argument, Node, Target, _get_qualified_name
2021
from torch_tensorrt.fx.converter_registry import CONVERTERS
2122
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -101,6 +102,19 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
101102
capability_validator=capability_validator,
102103
)
103104

105+
# OpOverloads are only valid if they have a single overload, or
106+
# only the ["default", "out"] overloads, due to PyTorch conventions
107+
if isinstance(key, OpOverloadPacket) and (
108+
len(key.overloads()) >= 3
109+
or (len(key.overloads()) == 2 and "out" not in key.overloads())
110+
):
111+
raise AssertionError(
112+
f"Detected converter for OpOverloadPacket {key}. "
113+
"We do not support OpOverloadPacket-keyed converters with multiple overloads. "
114+
"Make sure to explicitly specify each converter overload. For instance "
115+
"aten.mean is not a valid key, but aten.mean.default is."
116+
)
117+
104118
# If a converter for this operator already exists, append the new converter to the list
105119
# Otherwise, start a new list
106120
if key in DYNAMO_ATEN_CONVERTERS:

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,41 @@ def forward(self, x):
157157
torch._dynamo.reset()
158158

159159

160+
class TestPacketOperator(TestCase):
161+
def test_packet_operator(self):
162+
class PacketAsOperator(torch.nn.Module):
163+
def forward(self, x):
164+
return torch.ops.aten.tanh(x)
165+
166+
# Operations expected to be removed in the traced graph
167+
expected_ops = {torch.ops.aten.tanh.default}
168+
unexpected_ops = {
169+
torch.ops.aten.tanh,
170+
}
171+
172+
inputs = [torch.rand((3, 5)).cuda()]
173+
174+
fx_graph = torch.fx.symbolic_trace(PacketAsOperator())
175+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
176+
fx_graph,
177+
inputs,
178+
expected_ops=expected_ops,
179+
unexpected_ops=unexpected_ops,
180+
min_block_size=1,
181+
)
182+
183+
self.assertEquals(
184+
len(unexpected_ops_seen),
185+
0,
186+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
187+
)
188+
189+
self.assertEquals(
190+
len(expected_ops_unseen),
191+
0,
192+
f"The following expected ops were not encountered: {expected_ops_unseen}",
193+
)
194+
195+
160196
if __name__ == "__main__":
161197
run_tests()

0 commit comments

Comments
 (0)