
Commit ff4d940

feat: Add ATen lowering pass system (#2280)

1 parent 0a939df commit ff4d940

File tree

13 files changed: +448 −53 lines changed
docsrc/contributors/writing_dynamo_aten_lowering_passes.rst

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a graph of ATen operators, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, custom operator insertion, or other operations on a `torch.fx.GraphModule`, then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same input object.
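
For instance, a pass of the following form could replace every `aten.relu` call with an equivalent `aten.clamp` call. This is a minimal illustrative sketch; the substitution it performs is hypothetical and is not a pass Torch-TensorRT ships:

.. code-block:: python

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def replace_relu_with_clamp(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Hypothetical pass: swap each aten.relu for an equivalent aten.clamp"""
        modified_graph = False

        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.relu.default:
                # clamp(x, min=0) computes the same result as relu(x)
                node.target = torch.ops.aten.clamp.default
                node.args = (node.args[0], 0.0)
                modified_graph = True

        # Leave the graph in a valid, runnable state (see requirements below)
        if modified_graph:
            gm.graph.lint()
            gm.recompile()

        return gm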
Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TRT must satisfy two requirements:
- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation

See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
Example Lowering Pass
---------------------

.. code-block:: python

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm
Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly, or with the optional `index` keyword argument which controls where in the pass list the lowering pass will be inserted.

For instance, to insert the pass at the default location (end of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used:

.. code-block:: python

    @_aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)

**Note:** The above APIs are subject to change, as the lowering pass system evolves.

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,7 @@ Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
 * :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
 * :ref:`useful_links`

 .. toctree::
@@ -137,6 +138,7 @@ Contributor Documentation

    contributors/system_overview
    contributors/writing_converters
+   contributors/writing_dynamo_aten_lowering_passes
    contributors/useful_links

 Indices

py/torch_tensorrt/dynamo/aten_tracer.py

Lines changed: 2 additions & 3 deletions
@@ -6,8 +6,7 @@

 import torch
 from torch._export import export
-from torch_tensorrt.dynamo.backend.backends import constant_fold
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import set_log_level

 logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        constant_fold(graph_module)
+        graph_module = apply_lowering_passes(graph_module)
         logger.debug("Post export graph: " + str(graph_module.graph))
         return graph_module
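
With this change, the tracer path applies the full registered pass list rather than constant folding alone. The passes can also be applied to an exported graph directly; a minimal sketch (the module and inputs are hypothetical, and the `torch._export.export` call mirrors the code above, so it may differ across Torch versions):

.. code-block:: python

    import torch
    from torch._export import export
    from torch_tensorrt.dynamo.lowering import apply_lowering_passes

    # Hypothetical module and inputs, for illustration only
    model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
    inputs = (torch.randn(2, 4),)

    graph_module = export(model, inputs).module()
    graph_module = apply_lowering_passes(graph_module)
    print(graph_module.graph)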

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 4 additions & 45 deletions
@@ -10,24 +10,12 @@
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
-from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level

-from packaging import version
-
-# Modify import location of utilities based on Torch version
-if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
-    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
-else:
-    from torch._inductor.constant_folding import (
-        ConstantFolder,
-        replace_node_with_constant,
-    )
-
 logger = logging.getLogger(__name__)


@@ -84,7 +72,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_for_compile(
+            gm = aot_export_for_compile(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -94,10 +82,10 @@

             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            constant_fold(graph_module)
+            gm = apply_lowering_passes(gm)

             trt_compiled = compile_module(
-                graph_module,
+                gm,
                 sample_inputs,
                 settings=settings,
             )
@@ -121,35 +109,6 @@
         raise


-@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> Any:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
-
-    Folds constants in the graph module, not skipping constructors
-
-    Modifies the graph in-place and replaces node with constants
-    """
-    cf = ConstantFolder(gm, skip_constructors=False)
-    cf.run()
-
-    for node, constant in cf.node_replacements.items():
-        replace_node_with_constant(gm, node, constant)
-
-    erased_params = []
-    for node in gm.graph.nodes:
-        if node.op == "get_attr" and len(node.users) == 0:
-            delattr(gm, node.target)
-            erased_params.append(node)
-
-    for node in erased_params:
-        gm.graph.erase_node(node)
-
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
-
-
 def aot_export_for_compile(
     func: torch.fx.GraphModule,
     args: Sequence[torch.Tensor],
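
The net effect of these backend changes is that every graph compiled through the Torch-TensorRT `torch.compile` backend now runs through `apply_lowering_passes`. A minimal sketch of exercising that path, assuming the backend is registered under the name "torch_tensorrt" and using a hypothetical module:

.. code-block:: python

    import torch
    import torch_tensorrt  # noqa: F401  # importing registers the Dynamo backend

    # Hypothetical module and input, for illustration only
    model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval().cuda()
    inputs = torch.randn(2, 4).cuda()

    # _pretraced_backend (and with it apply_lowering_passes) runs here
    compiled = torch.compile(model, backend="torch_tensorrt")
    compiled(inputs)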

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from .passes import apply_lowering_passes
 from .substitutions import *  # noqa: F401
py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from ._aten_lowering_pass import *
py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
import logging
from typing import Callable, Optional

import torch

from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        constant_fold,
        repair_input_as_output,
    ]
)

logger = logging.getLogger(__name__)


LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]


def _aten_lowering_pass(
    *args: LoweringPassSignature,
    index: Optional[int] = None,
) -> LoweringPassSignature:
    """Adds a lowering pass to the registry, at a specified index if desired

    If no index is specified, the lowering pass is inserted at the end of the list
    """

    def add_lowering_pass(
        lowering_pass: LoweringPassSignature,
    ) -> LoweringPassSignature:
        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
        logger.debug(
            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
        )
        return lowering_pass

    # If there are arguments specified, the decorator may have been called as-is
    if args:
        # The decorator may only be called with the lowering pass
        # The index must be specified as a keyword argument
        if len(args) == 1 and callable(args[0]):
            return add_lowering_pass(args[0])
        else:
            raise AssertionError(
                f"aten_lowering_pass decorator called with invalid arguments {args} "
                "To specify an index to insert the pass, use the keyword 'index='"
            )
    # If no arguments are specified, the decorator was called with an index keyword
    else:
        return add_lowering_pass


def _remove_lowering_pass(*, index: int) -> None:
    """Removes a lowering pass at a specific index from the registry"""
    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
    )
    return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
    )
    return ATEN_LOWERING_PASSES(gm)


def dump_lowering_passes() -> str:
    """Returns a string containing the lowering passes"""
    return str(ATEN_LOWERING_PASSES)
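
`pass_manager.py` is among the changed files but is not reproduced in this excerpt. Below is a minimal sketch of a `DynamoPassManager` consistent with the calls above (`build_from_passlist`, `add_pass_with_index`, `remove_pass_with_index`, direct invocation, and `str()`), built on the `torch.fx.passes.pass_manager.PassManager` utility named in the documentation; the actual implementation may differ:

.. code-block:: python

    from typing import Callable, List, Optional

    import torch
    from torch.fx.passes.pass_manager import PassManager

    PassType = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]


    class DynamoPassManager(PassManager):  # sketch only
        @classmethod
        def build_from_passlist(
            cls, passes: Optional[List[PassType]]
        ) -> "DynamoPassManager":
            return cls(passes)

        def add_pass_with_index(
            self, lowering_pass: PassType, index: Optional[int] = None
        ) -> None:
            # No index: append to the end of the pass list (the documented default)
            if index is None:
                self.passes.append(lowering_pass)
            else:
                self.passes.insert(index, lowering_pass)

        def remove_pass_with_index(self, index: int) -> None:
            del self.passes[index]

        def __call__(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
            # PassManager.__call__ runs each registered pass in order
            return super().__call__(gm)

        def __str__(self) -> str:
            return str(self.passes)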
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
import logging

import torch
from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

# Modify import location of utilities based on Torch version
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
else:
    from torch._inductor.constant_folding import (
        ConstantFolder,
        replace_node_with_constant,
    )

logger = logging.getLogger(__name__)


@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors

    Modifies the graph in-place and replaces node with constants
    """
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()

    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        # If get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            # If the node's parameter is not a parameter of any other node, remove it
            if not any(
                other.target == node.target for other in gm.graph.nodes if other != node
            ):
                delattr(gm, node.target)
            erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

    return gm
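
As a rough usage sketch, `constant_fold` collapses subexpressions that depend only on frozen attributes. The module below is hypothetical, the import path follows the file location above, and the exact folding behavior depends on the `ConstantFolder` heuristics of the installed Torch version:

.. code-block:: python

    import torch
    from torch_tensorrt.dynamo.lowering.passes.constant_folding import constant_fold

    # Hypothetical module: `self.scale * 3.0` depends only on a buffer,
    # so it is a candidate to be folded into a single precomputed constant
    class Example(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("scale", torch.full((4,), 2.0))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * (self.scale * 3.0)

    gm = torch.fx.symbolic_trace(Example())
    gm = constant_fold(gm)  # the buffer-only mul should be replaced by a constant
    print(gm.graph)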
