Skip to content

Commit 90eec47

Browse files
committed
feat: Refactor pass manager and utilities
- Improve logging and pass manager utilities - Add testing of new utilities
1 parent 7ebadcd commit 90eec47

File tree

4 files changed

+130
-6
lines changed

4 files changed

+130
-6
lines changed

examples/dynamo/dynamo_aten_lowering_passes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,24 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
8686
# %%
8787
from torch_tensorrt.dynamo.lowering.passes import add_lowering_pass
8888

89+
# Adds the lowering pass at the end of the pass list
8990
add_lowering_pass(repair_input_as_output)
9091

92+
# Alternatively, specify an index to insert the lowering pass at a specific location
93+
add_lowering_pass(repair_input_as_output, 1)
94+
95+
# To remove a lowering pass, specify the index of the pass to remove:
96+
from torch_tensorrt.dynamo.lowering.passes import remove_lowering_pass
97+
98+
# Removes the lowering pass at index 1
99+
remove_lowering_pass(1)
100+
101+
102+
# To view all lowering passes, in the order they will be run, use the following
103+
from torch_tensorrt.dynamo.lowering.passes import dump_lowering_passes
104+
105+
print(dump_lowering_passes())
106+
91107
# %%
92108
# 3. Apply Available Lowering Passes
93109
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,55 @@
1-
from typing import Callable
1+
import logging
2+
from typing import Callable, Optional
23

34
import torch
4-
from torch.fx.passes.pass_manager import PassManager
55

6+
# Import and order lowering passes and pass manager
67
from .constant_folding import constant_fold
8+
from .pass_manager import DynamoPassManager
79
from .repair_input_as_output import repair_input_as_output
810

9-
ATEN_LOWERING_PASSES = PassManager.build_from_passlist(
11+
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1012
[
1113
constant_fold,
1214
repair_input_as_output,
1315
]
1416
)
1517

18+
logger = logging.getLogger(__name__)
19+
1620

1721
def add_lowering_pass(
18-
lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
22+
lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
23+
index: Optional[int] = None,
1924
) -> None:
20-
"""Adds a lowering pass to the registry"""
21-
ATEN_LOWERING_PASSES.add_pass(lowering_pass)
25+
"""Adds a lowering pass to the registry, at a specified index if desired
26+
27+
If no index is specified, the lowering pass is inserted at the end of the list
28+
"""
29+
ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
30+
logger.debug(
31+
f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
32+
)
33+
return
34+
35+
36+
def remove_lowering_pass(index: int) -> None:
37+
"""Removes a lowering pass at a specific index from the registry"""
38+
ATEN_LOWERING_PASSES.remove_pass_with_index(index)
39+
logger.debug(
40+
f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
41+
)
2242
return
2343

2444

2545
def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
2646
"""Applies the lowering passes to a graph module, returns the modified GraphModule"""
47+
logging.debug(
48+
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
49+
)
2750
return ATEN_LOWERING_PASSES(gm)
51+
52+
53+
def dump_lowering_passes() -> str:
54+
"""Returns a string containing the lowering passes"""
55+
return str(ATEN_LOWERING_PASSES)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Any, Callable, List, Optional
2+
3+
import torch
4+
from torch.fx.passes.pass_manager import PassManager
5+
6+
7+
class DynamoPassManager(PassManager): # type: ignore[misc]
8+
def __init__(
9+
self,
10+
passes: Optional[
11+
List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]
12+
] = None,
13+
):
14+
super().__init__(passes)
15+
16+
@classmethod
17+
def build_from_passlist(
18+
cls,
19+
passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]],
20+
) -> Any:
21+
pm = DynamoPassManager(passes)
22+
return pm
23+
24+
def add_pass_with_index(
25+
self,
26+
lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule],
27+
index: Optional[int] = None,
28+
) -> None:
29+
if index is None:
30+
self.passes.append(lowering_pass)
31+
index = -1
32+
else:
33+
self.passes.insert(index, lowering_pass)
34+
35+
def remove_pass_with_index(self, index: int) -> None:
36+
del self.passes[index]
37+
38+
def __call__(self, source: Any) -> Any:
39+
return super().__call__(source)
40+
41+
def __str__(self) -> str:
42+
return str(self.passes)

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,43 @@ def forward(self, x, y):
5555
torch._dynamo.reset()
5656

5757

58+
class TestLoweringPassMembership(TestCase):
59+
def insert_at_end(self):
60+
from torch_tensorrt.dynamo.lowering.passes import (
61+
ATEN_LOWERING_PASSES,
62+
add_lowering_pass,
63+
remove_lowering_pass,
64+
)
65+
66+
def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
67+
return gm
68+
69+
add_lowering_pass(identity_pass)
70+
71+
self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1])
72+
73+
remove_lowering_pass(-1)
74+
75+
self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
76+
77+
def insert_at_index(self):
78+
from torch_tensorrt.dynamo.lowering.passes import (
79+
ATEN_LOWERING_PASSES,
80+
add_lowering_pass,
81+
remove_lowering_pass,
82+
)
83+
84+
def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
85+
return gm
86+
87+
add_lowering_pass(identity_pass, 0)
88+
89+
self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0])
90+
91+
remove_lowering_pass(0)
92+
93+
self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes)
94+
95+
5896
if __name__ == "__main__":
5997
run_tests()

0 commit comments

Comments
 (0)