Skip to content

fix: Raise error when registering Packet-keyed converter #2285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
cast,
)

from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.fx.converter_registry import CONVERTERS
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
Expand Down Expand Up @@ -101,6 +102,19 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat
capability_validator=capability_validator,
)

# OpOverloadPackets are only valid if they have a single overload, or
# only the ["default", "out"] overloads, due to PyTorch conventions
if isinstance(key, OpOverloadPacket) and (
len(key.overloads()) >= 3
or (len(key.overloads()) == 2 and "out" not in key.overloads())
):
raise AssertionError(
f"Detected converter for OpOverloadPacket {key}. "
"We do not support OpOverloadPacket-keyed converters with multiple overloads. "
"Make sure to explicitly specify each converter overload. For instance "
"aten.mean is not a valid key, but aten.mean.default is."
)

# If a converter for this operator already exists, append the new converter to the list
# Otherwise, start a new list
if key in DYNAMO_ATEN_CONVERTERS:
Expand Down
37 changes: 37 additions & 0 deletions tests/py/dynamo/backend/test_specialized_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,42 @@ def forward(self, x):
torch._dynamo.reset()


class TestPacketOperator(TestCase):
def test_packet_operator(self):
class PacketAsOperator(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.tanh(x)

# Operations expected to be removed in the traced graph
expected_ops = {torch.ops.aten.tanh.default}
unexpected_ops = {
torch.ops.aten.tanh,
}
Comment on lines +241 to +249
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validates that the Packet callable torch.ops.aten.tanh is transformed into torch.ops.aten.tanh.default by AOT autograd.


inputs = [torch.rand((3, 5)).cuda()]

fx_graph = torch.fx.symbolic_trace(PacketAsOperator())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()


if __name__ == "__main__":
run_tests()