Skip to content

Commit b50290d

Browse files
authored
fix: Raise error when registering Packet-keyed converter (#2285)
1 parent ac007ce commit b50290d

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

py/torch_tensorrt/dynamo/conversion/converter_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
cast,
1717
)
1818

19+
from torch._ops import OpOverloadPacket
1920
from torch.fx.node import Argument, Node, Target, _get_qualified_name
2021
from torch_tensorrt.fx.converter_registry import CONVERTERS
2122
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -101,6 +102,19 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
101102
capability_validator=capability_validator,
102103
)
103104

105+
# OpOverloadPackets are only valid if they have a single overload, or
106+
# only the ["default", "out"] overloads, due to PyTorch conventions
107+
if isinstance(key, OpOverloadPacket) and (
108+
len(key.overloads()) >= 3
109+
or (len(key.overloads()) == 2 and "out" not in key.overloads())
110+
):
111+
raise AssertionError(
112+
f"Detected converter for OpOverloadPacket {key}. "
113+
"We do not support OpOverloadPacket-keyed converters with multiple overloads. "
114+
"Make sure to explicitly specify each converter overload. For instance "
115+
"aten.mean is not a valid key, but aten.mean.default is."
116+
)
117+
104118
# If a converter for this operator already exists, append the new converter to the list
105119
# Otherwise, start a new list
106120
if key in DYNAMO_ATEN_CONVERTERS:

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,5 +236,42 @@ def forward(self, x):
236236
torch._dynamo.reset()
237237

238238

239+
class TestPacketOperator(TestCase):
    def test_packet_operator(self):
        """Verify that lowering resolves an OpOverloadPacket call into its
        concrete overload.

        The module below intentionally calls the packet ``torch.ops.aten.tanh``
        (not an overload such as ``aten.tanh.default``); after tracing and
        lowering, only the resolved overload should remain in the graph.
        """

        class PacketAsOperator(torch.nn.Module):
            def forward(self, x):
                # Call the raw packet so lowering must resolve the overload.
                return torch.ops.aten.tanh(x)

        # Ops expected to be PRESENT in the lowered graph (the resolved
        # overload), and ops that must NOT survive (the raw packet).
        expected_ops = {torch.ops.aten.tanh.default}
        unexpected_ops = {
            torch.ops.aten.tanh,
        }

        inputs = [torch.rand((3, 5)).cuda()]

        fx_graph = torch.fx.symbolic_trace(PacketAsOperator())
        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        # NOTE: assertEqual (not the deprecated assertEquals alias, which was
        # removed in Python 3.12).
        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()
274+
275+
239276
if __name__ == "__main__":
240277
run_tests()

0 commit comments

Comments (0)