
Commit 4f72425

feat: Add maxpool lowering passes and experimental folder in Dynamo (#2358)
1 parent a7f9055 commit 4f72425

File tree

16 files changed (+246, -489 lines)


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ repos:
     rev: 'v1.4.1'
     hooks:
       - id: mypy
-        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
+        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
     rev: v0.0.278

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 4 deletions
@@ -15,7 +15,6 @@
     get_decompositions,
     repair_input_aliasing,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import (
     parse_dynamo_kwargs,
     prepare_inputs,
@@ -68,9 +67,6 @@ def _pretraced_backend(
     try:
         logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))

-        # Perform Pre-AOT Lowering for Module-Level Replacement
-        gm = pre_aot_substitutions(gm)
-
         fake_mode = detect_fake_mode(sample_inputs)

         # Place backend tracing within FakeTensor context allowing nonfake Tensors
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
from ._decompositions import get_decompositions # noqa: F401
22
from ._fusers import * # noqa: F401
3-
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
4-
from ._pre_aot_lowering import register_substitution # noqa: F401
53
from ._repair_input_aliasing import repair_input_aliasing
64
from .passes import apply_lowering_passes
7-
from .substitutions import * # noqa: F401

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

Lines changed: 0 additions & 145 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
+from .replace_max_pool_with_indices import replace_max_pool_with_indices

 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -17,6 +18,7 @@
         repair_input_as_output,
         lower_efficient_attention,
         fuse_prims_broadcast,
+        replace_max_pool_with_indices,
     ]
 )

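The new pass is appended to the ordered list that DynamoPassManager builds, so it runs after the alias-repair, attention, and prims-broadcast passes. As a rough sketch of what applying such a list amounts to, the helper below runs each pass in order; run_pass_list is an illustrative stand-in, not the real DynamoPassManager API, and it assumes every pass shares the (gm, sample_inputs) -> gm signature that replace_max_pool_with_indices uses.

# Illustrative sketch only -- not the actual DynamoPassManager implementation.
# Assumes each lowering pass takes (gm, sample_inputs) and returns a GraphModule.
from typing import Callable, List, Sequence

import torch

LoweringPass = Callable[
    [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule
]


def run_pass_list(
    passes: List[LoweringPass],
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
) -> torch.fx.GraphModule:
    # Passes are applied in list order; replace_max_pool_with_indices therefore
    # sees the graph only after the earlier passes have rewritten it.
    for lowering_pass in passes:
        gm = lowering_pass(gm, sample_inputs)
    return gm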

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import logging
2+
import operator
3+
from typing import Sequence
4+
5+
import torch
6+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
7+
clean_up_graph_after_modifications,
8+
)
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def replace_max_pool_with_indices(
14+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
15+
) -> torch.fx.GraphModule:
16+
"""Replace MaxPool nodes which return unused indices"""
17+
replacement_dict = {
18+
torch.ops.aten.max_pool1d_with_indices.default: torch.ops.aten.max_pool1d.default,
19+
torch.ops.aten.max_pool2d_with_indices.default: torch.ops.aten.max_pool2d.default,
20+
torch.ops.aten.max_pool3d_with_indices.default: torch.ops.aten.max_pool3d.default,
21+
}
22+
23+
modified_graph = False
24+
25+
for node in gm.graph.nodes:
26+
# If the node is a placeholder and its only user is a clone node
27+
# it was modified by the input alias-fixing pass, and the change
28+
# needs to be undone
29+
if (
30+
node.target in replacement_dict
31+
and len(node.users) == 1
32+
and list(node.users)[0].target == operator.getitem
33+
and list(node.users)[0].args[1] == 0
34+
):
35+
modified_graph = True
36+
37+
# Replace all uses of the clone with the placholder, delete the clone
38+
getitem_node = list(node.users)[0]
39+
40+
with gm.graph.inserting_after(getitem_node):
41+
maxpool_fused = gm.graph.call_function(
42+
replacement_dict[node.target],
43+
args=node.args,
44+
kwargs=node.kwargs,
45+
)
46+
47+
logger.debug(
48+
f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "
49+
f"is the only user of placeholder {node} and was inserted by the compiler."
50+
)
51+
52+
getitem_node.replace_all_uses_with(maxpool_fused)
53+
gm.graph.erase_node(getitem_node)
54+
gm.graph.erase_node(node)
55+
56+
if modified_graph:
57+
gm = clean_up_graph_after_modifications(gm)
58+
logger.debug(f"Graph after fusing maxpool operators with indices:\n{gm.graph}")
59+
60+
return gm
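
To see the exact pattern this pass targets, here is a small, hedged example that is not part of the commit: it symbolically traces a module whose max_pool2d_with_indices indices output is never used, then runs the new pass on the traced graph. It assumes the pass can be imported standalone from the module added above and that torch.fx symbolic tracing records the aten op call and the out[0] access as call_function nodes (the op overload and operator.getitem, respectively).

# Minimal sketch (not part of this commit): exercise the new lowering pass on a
# traced graph that matches the max_pool-with-unused-indices pattern.
import torch
from torch.fx import symbolic_trace
from torch_tensorrt.dynamo.lowering.passes.replace_max_pool_with_indices import (
    replace_max_pool_with_indices,
)


class PoolOnly(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The op returns (values, indices); only the values are consumed
        out = torch.ops.aten.max_pool2d_with_indices.default(x, [2, 2], [2, 2])
        return out[0]


gm = symbolic_trace(PoolOnly())
print(gm.graph)  # max_pool2d_with_indices followed by getitem(..., 0)

gm = replace_max_pool_with_indices(gm, [torch.randn(1, 1, 4, 4)])
print(gm.graph)  # single max_pool2d node; the indices path has been removed

In the compiled path this rewrite happens automatically, since the pass is registered in ATEN_LOWERING_PASSES above, presumably so that only the indices-free max_pool variants need TensorRT converters.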

py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

Lines changed: 0 additions & 76 deletions
This file was deleted.
