Skip to content

Commit f23cbb7

Browse files
committed
fix: Repair input aliasing with clone insertion
1 parent 5a5e235 commit f23cbb7

File tree

11 files changed

+114
-65
lines changed

11 files changed

+114
-65
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch_tensorrt.dynamo.lowering import (
1414
apply_lowering_passes,
1515
get_decompositions,
16-
replace_builtin_inplace_ops,
16+
repair_input_aliasing,
1717
)
1818
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1919
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -75,12 +75,13 @@ def _pretraced_backend(
7575
with unittest.mock.patch.object(
7676
fake_mode, "allow_non_fake_inputs", True
7777
), fake_mode:
78-
replace_builtin_inplace_ops(gm)
78+
repair_input_aliasing(gm)
7979

8080
# Invoke AOTAutograd to translate operators to aten
8181
gm = aot_export_joint_simple(
8282
gm,
8383
sample_inputs,
84+
trace_joint=False,
8485
decompositions=get_decompositions(
8586
settings.enable_experimental_decompositions
8687
),

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from ._fusers import * # noqa: F401
33
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
44
from ._pre_aot_lowering import register_substitution # noqa: F401
5-
from ._replace_inplace_ops import replace_builtin_inplace_ops
5+
from ._repair_input_aliasing import repair_input_aliasing
66
from .passes import add_lowering_pass, apply_lowering_passes
77
from .substitutions import * # noqa: F401
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import logging
2+
3+
import torch
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Insert a clone operator after every placeholder node.

    Temporary workaround for input aliasing/mutation issues in
    ``aot_export_joint_simple``.

    See: https://github.com/pytorch/pytorch/issues/108079
    Undone by `remove_input_alias_fixing_clones` after tracing

    Args:
        gm: Graph module to modify in place.

    Returns:
        The same ``gm``, with one ``aten.clone`` node per placeholder and all
        other uses of each placeholder redirected to its clone.
    """
    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]

    # Nothing to repair if the graph takes no inputs
    if not placeholders:
        return gm

    # Hoisted loop-invariant: insert every clone after the *last* placeholder
    # so placeholders remain a contiguous prefix of the node list
    insertion_point = placeholders[-1]

    for placeholder in placeholders:
        # Insert clones for placeholder nodes to avoid
        # input aliasing or mutation
        with gm.graph.inserting_after(insertion_point):
            cloned_input = gm.graph.call_function(
                torch.ops.aten.clone.default,
                args=(placeholder,),
            )

        # Replace all uses of the placeholder except the cloned node
        # with the cloned placeholder. The clone is bound as a default
        # argument to avoid late-binding/shadowing pitfalls in the callback.
        placeholder.replace_all_uses_with(
            cloned_input,
            delete_user_cb=lambda user, clone=cloned_input: user is not clone,
        )

    gm.graph.lint()
    gm.recompile()
    # Lazy %-formatting: graph stringification is only paid at DEBUG level
    logging.getLogger(__name__).debug(
        "Inserted auxiliary clone nodes for placeholders:\n%s", gm.graph
    )

    return gm

py/torch_tensorrt/dynamo/lowering/_replace_inplace_ops.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from torch.fx.passes.pass_manager import PassManager
55

66
from .constant_folding import constant_fold
7+
8+
# Import and order lowering passes
9+
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
710
from .repair_input_as_output import repair_input_as_output
811

912
ATEN_LOWERING_PASSES = PassManager.build_from_passlist(
1013
[
14+
remove_input_alias_fixing_clones,
1115
constant_fold,
1216
repair_input_as_output,
1317
]

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import torch
44
from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
5+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
6+
clean_up_graph_after_modifications,
7+
)
58

69
logger = logging.getLogger(__name__)
710

@@ -30,9 +33,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
3033
for node in erased_params:
3134
gm.graph.erase_node(node)
3235

33-
gm.graph.eliminate_dead_code()
34-
gm.graph.lint()
35-
gm.recompile()
36+
gm = clean_up_graph_after_modifications(gm)
3637

3738
logger.debug(f"Graph after constant folding:\n{gm.graph}")
3839

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
3+
4+
def clean_up_graph_after_modifications(
    gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """Standard cleanup after rewriting an FX graph.

    Eliminates dead code, lints the graph, and regenerates the module's
    forward code. All steps mutate ``gm`` in place; the same module is
    returned so callers can chain on the result.
    """
    graph = gm.graph
    graph.eliminate_dead_code()
    graph.lint()
    gm.recompile()
    return gm
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import logging
2+
3+
import torch
4+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
5+
clean_up_graph_after_modifications,
6+
)
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
# TODO: Delete this lowering pass once aot_export_joint_simple is patched
12+
# TODO: Delete this lowering pass once aot_export_joint_simple is patched
def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Undo the auxiliary clone nodes inserted to fix input aliasing.

    A placeholder whose single user is an ``aten.clone`` call is taken to
    have been rewritten by the alias-fixing pass; that clone is spliced out
    and its consumers are reconnected directly to the placeholder.

    See: https://github.com/pytorch/pytorch/issues/108079
    """
    graph_changed = False

    for node in gm.graph.nodes:
        # Only placeholders with exactly one user are candidates
        if node.op != "placeholder" or len(node.users) != 1:
            continue

        (sole_user,) = node.users
        if sole_user.target != torch.ops.aten.clone.default:
            continue

        # The clone was added by the input alias-fixing pass; undo it by
        # reconnecting its consumers to the placeholder, then erasing it
        graph_changed = True
        sole_user.replace_all_uses_with(node)
        gm.graph.erase_node(sole_user)

    if graph_changed:
        # Cleanup (inlined): dead-code elimination, lint, recompile
        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
        logging.getLogger(__name__).debug(
            "Removed auxiliary clone nodes for placeholders:\n%s", gm.graph
        )

    return gm

py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
22

33
import torch
4+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
5+
clean_up_graph_after_modifications,
6+
)
47

58
logger = logging.getLogger(__name__)
69

@@ -37,9 +40,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
3740
output.replace_input_with(placeholder, cloned_placeholder)
3841

3942
if modified_graph:
40-
gm.graph.eliminate_dead_code()
41-
gm.graph.lint()
42-
gm.recompile()
43+
gm = clean_up_graph_after_modifications(gm)
4344
logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
4445

4546
return gm

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,12 @@ def forward(self, x):
280280

281281
def test_input_modifications_mul(self):
282282
class InplaceMul(torch.nn.Module):
283-
def forward(self, x):
283+
def forward(self, x, y):
284284
x *= 5.0
285285
x *= 1.9
286-
y = x + 1
287-
y /= 1.3
288-
return y
286+
z = x + y
287+
z /= 1.3
288+
return z
289289

290290
inputs = [
291291
torch.rand(
@@ -294,6 +294,12 @@ def forward(self, x):
294294
5,
295295
7,
296296
).cuda(),
297+
torch.rand(
298+
1,
299+
3,
300+
5,
301+
7,
302+
).cuda(),
297303
]
298304

299305
fx_graph = torch.fx.symbolic_trace(InplaceMul())

tests/py/dynamo/testing_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch_tensorrt.dynamo.lowering import (
1111
apply_lowering_passes,
1212
get_decompositions,
13-
replace_builtin_inplace_ops,
13+
repair_input_aliasing,
1414
)
1515
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1616

@@ -43,7 +43,7 @@ def fx_dynamo_testing_backend(
4343
with unittest.mock.patch.object(
4444
fake_mode, "allow_non_fake_inputs", True
4545
), fake_mode:
46-
replace_builtin_inplace_ops(gm)
46+
repair_input_aliasing(gm)
4747

4848
# Invoke AOTAutograd to translate operators to aten
4949
gm = aot_export_joint_simple(

0 commit comments

Comments
 (0)