
Commit 43d21b2

fix: Repair input aliasing with clone insertion
1 parent a4056cc commit 43d21b2

10 files changed: 97 additions and 61 deletions

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     ATEN_LOWERING_PASSES,
     get_decompositions,
-    replace_builtin_inplace_ops,
+    repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -75,7 +75,7 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
-            replace_builtin_inplace_ops(gm)
+            repair_input_aliasing(gm)

             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -2,6 +2,6 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
-from ._replace_inplace_ops import replace_builtin_inplace_ops
+from ._repair_input_aliasing import repair_input_aliasing
 from .passes import ATEN_LOWERING_PASSES
 from .substitutions import *  # noqa: F401
py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Inserts clone operators temporarily ahead of every placeholder
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    Undone by `remove_input_alias_fixing_clones` after tracing
+    """
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            # Insert clone for placeholder node to avoid
+            # input aliasing or mutation
+            with gm.graph.inserting_after(node):
+                cloned_input = gm.graph.call_function(
+                    torch.ops.aten.clone.default,
+                    args=(node,),
+                )
+
+            # Replace all uses of the placeholder except the cloned node
+            # with the cloned placeholder
+            node.replace_all_uses_with(
+                cloned_input,
+                delete_user_cb=lambda node: node != cloned_input,
+            )
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
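
For context, a minimal sketch (not part of this commit) of what this pass does to a toy FX graph; the ToyAdd module and the graph comments are illustrative only, and exact node names may differ:

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing


class ToyAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = torch.fx.symbolic_trace(ToyAdd())
# Before: the add consumes the placeholders directly
#   %x, %y (placeholders) -> %add

repair_input_aliasing(gm)
# After: each placeholder feeds an aten.clone.default node and the add
# consumes the clones, so downstream tracing never aliases the raw inputs
#   %x -> %clone, %y -> %clone_1 -> %add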

py/torch_tensorrt/dynamo/lowering/_replace_inplace_ops.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -1,10 +1,15 @@
 from torch.fx.passes.pass_manager import PassManager

 from .constant_folding import constant_fold
+from .pass_utils import clean_up_graph_after_modifications
+
+# Import and order lowering passes
+from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output

 ATEN_LOWERING_PASSES = PassManager.build_from_passlist(
     [
+        remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
     ]
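
For reference, a brief sketch (not from this commit) of how the assembled pass list is applied: ATEN_LOWERING_PASSES is a torch.fx PassManager, so calling it on an aten-level GraphModule runs the passes in the listed order, with the clone-removal pass first so later passes see a clone-free graph.

from torch_tensorrt.dynamo.lowering import ATEN_LOWERING_PASSES

# gm: a torch.fx.GraphModule produced by AOT export (assumed to exist here)
gm = ATEN_LOWERING_PASSES(gm)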

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 2 additions & 3 deletions

@@ -2,6 +2,7 @@

 import torch
 from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch_tensorrt.dynamo.lowering.passes import clean_up_graph_after_modifications

 logger = logging.getLogger(__name__)

@@ -30,9 +31,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     for node in erased_params:
         gm.graph.erase_node(node)

-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
+    gm = clean_up_graph_after_modifications(gm)

     logger.debug(f"Graph after constant folding:\n{gm.graph}")

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+import torch
+
+
+def clean_up_graph_after_modifications(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Runs dead-code elimination, linting, and recompilation for graph, in-place"""
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
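
A short sketch (not part of the commit) of the intended usage pattern: a graph-mutating pass performs its edits, then delegates cleanup to the shared helper. The remove_detach pass below is hypothetical and only illustrates the pattern:

import torch
from torch_tensorrt.dynamo.lowering.passes import clean_up_graph_after_modifications


def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical pass: reroute aten.detach nodes to their inputs"""
    detach_nodes = [
        node
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == torch.ops.aten.detach.default
    ]
    for node in detach_nodes:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

    # One call handles dead-code elimination, lint, and recompile
    return clean_up_graph_after_modifications(gm)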
py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes import clean_up_graph_after_modifications
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Delete this lowering pass once aot_export_joint_simple is patched
+def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove the auxiliary clone nodes inserted to fix input aliasing
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    """
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node is a placeholder and its only user is a clone node,
+        # it was modified by the input alias-fixing pass, and the change
+        # needs to be undone
+        if (
+            node.op == "placeholder"
+            and len(node.users) == 1
+            and list(node.users)[0].target == torch.ops.aten.clone.default
+        ):
+            modified_graph = True
+
+            # Replace all uses of the clone with the placeholder, then delete the clone
+            clone_node = list(node.users)[0]
+            clone_node.replace_all_uses_with(node)
+            gm.graph.erase_node(clone_node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
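
Taken together with repair_input_aliasing, the two passes are meant to round-trip: clones are added before aot_export_joint_simple traces the graph and stripped out again by this pass. A minimal sketch (not from the commit; ToyMul is illustrative only):

import torch
from torch_tensorrt.dynamo.lowering import repair_input_aliasing
from torch_tensorrt.dynamo.lowering.passes import remove_input_alias_fixing_clones


class ToyMul(torch.nn.Module):
    def forward(self, x):
        return x * 2


gm = torch.fx.symbolic_trace(ToyMul())

repair_input_aliasing(gm)                  # placeholder -> aten.clone -> mul
gm = remove_input_alias_fixing_clones(gm)  # clone removed again: placeholder -> mul

# No clone nodes should remain after the round trip
assert all(
    node.target != torch.ops.aten.clone.default for node in gm.graph.nodes
)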

py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py

Lines changed: 2 additions & 3 deletions

@@ -1,6 +1,7 @@
 import logging

 import torch
+from torch_tensorrt.dynamo.lowering.passes import clean_up_graph_after_modifications

 logger = logging.getLogger(__name__)

@@ -37,9 +38,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
             output.replace_input_with(placeholder, cloned_placeholder)

     if modified_graph:
-        gm.graph.eliminate_dead_code()
-        gm.graph.lint()
-        gm.recompile()
+        gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

     return gm

tests/py/dynamo/testing_utilities.py

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 from torch_tensorrt.dynamo.lowering import (
     ATEN_LOWERING_PASSES,
     get_decompositions,
-    replace_builtin_inplace_ops,
+    repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions

@@ -43,7 +43,7 @@ def fx_dynamo_testing_backend(
     with unittest.mock.patch.object(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
-        replace_builtin_inplace_ops(gm)
+        repair_input_aliasing(gm)

         # Invoke AOTAutograd to translate operators to aten
         gm = aot_export_joint_simple(
