Skip to content

Commit c63e411

Browse files
committed
fix: Repair input aliasing with clone insertion
1 parent 80a8da2 commit c63e411

File tree

11 files changed

+138
-75
lines changed

11 files changed

+138
-75
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch_tensorrt.dynamo.lowering import (
1414
apply_lowering_passes,
1515
get_decompositions,
16-
replace_builtin_inplace_ops,
16+
repair_input_aliasing,
1717
)
1818
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1919
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -76,12 +76,13 @@ def _pretraced_backend(
7676
with unittest.mock.patch.object(
7777
fake_mode, "allow_non_fake_inputs", True
7878
), fake_mode:
79-
replace_builtin_inplace_ops(gm)
79+
repair_input_aliasing(gm)
8080

8181
# Invoke AOTAutograd to translate operators to aten
8282
gm = aot_export_joint_simple(
8383
gm,
8484
sample_inputs,
85+
trace_joint=False,
8586
decompositions=get_decompositions(
8687
settings.enable_experimental_decompositions
8788
),

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from ._fusers import * # noqa: F401
33
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
44
from ._pre_aot_lowering import register_substitution # noqa: F401
5-
from ._replace_inplace_ops import replace_builtin_inplace_ops
5+
from ._repair_input_aliasing import repair_input_aliasing
66
from .passes import add_lowering_pass, apply_lowering_passes
77
from .substitutions import * # noqa: F401
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
3+
import torch
4+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_tensor_placeholders
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Temporarily routes every use of each Tensor placeholder through a clone

    One `aten.clone` node is created per Tensor placeholder, positioned
    directly after the last Tensor placeholder of the graph, and all other
    consumers of that placeholder are rewired to read from the clone.
    The graph is modified in-place and the same GraphModule is returned.

    See: https://github.com/pytorch/pytorch/issues/108079
    Undone by `remove_input_alias_fixing_clones` after tracing
    """
    graph = gm.graph
    tensor_inputs = get_tensor_placeholders(gm)

    for tensor_input in tensor_inputs:
        # Each clone is anchored after the final Tensor placeholder so that
        # no clone is emitted in between the Tensor placeholders. The clone
        # shields the original input from aliasing or mutation.
        with graph.inserting_after(tensor_inputs[-1]):
            input_clone = graph.call_function(
                torch.ops.aten.clone.default,
                args=(tensor_input,),
            )

        def is_not_the_clone(
            user: torch.fx.Node, clone: torch.fx.Node = input_clone
        ) -> bool:
            # The clone must keep reading from the original placeholder,
            # so it is the one user which is never redirected
            return user != clone

        # Point every other consumer of the placeholder at the clone
        tensor_input.replace_all_uses_with(
            input_clone,
            delete_user_cb=is_not_the_clone,
        )

    graph.lint()
    gm.recompile()
    logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")

    return gm

py/torch_tensorrt/dynamo/lowering/_replace_inplace_ops.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
# Import and order lowering passes and pass manager
77
from .constant_folding import constant_fold
88
from .pass_manager import DynamoPassManager
9+
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
910
from .repair_input_as_output import repair_input_as_output
1011

1112
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1213
[
14+
remove_input_alias_fixing_clones,
1315
constant_fold,
1416
repair_input_as_output,
1517
]

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import torch
44
from torch_tensorrt._utils import sanitized_torch_version
5+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
6+
clean_up_graph_after_modifications,
7+
)
58

69
from packaging import version
710

@@ -47,9 +50,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
4750
for node in erased_params:
4851
gm.graph.erase_node(node)
4952

50-
gm.graph.eliminate_dead_code()
51-
gm.graph.lint()
52-
gm.recompile()
53+
gm = clean_up_graph_after_modifications(gm)
5354

5455
logger.debug(f"Graph after constant folding:\n{gm.graph}")
5556

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import List
2+
3+
import torch
4+
5+
6+
def clean_up_graph_after_modifications(
    gm: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """Finalizes a GraphModule whose graph was edited, in-place

    Performs, in order: dead-code elimination on the graph, a lint check
    of the graph, and recompilation of the module. The same GraphModule
    object which was passed in is returned.
    """
    graph = gm.graph

    # Drop nodes which no longer have users, then validate what remains
    graph.eliminate_dead_code()
    graph.lint()

    # Regenerate the module's Python code from the updated graph
    gm.recompile()
    return gm
14+
15+
16+
def get_tensor_placeholders(
    gm: torch.fx.GraphModule,
) -> List[torch.fx.Node]:
    """Collects the placeholder nodes of a GraphModule annotated as Tensors

    A placeholder qualifies only if its `type` attribute is a class which
    is `torch.Tensor` or a subclass of it. Nodes are returned in graph order.
    """
    tensor_placeholders: List[torch.fx.Node] = []

    for candidate in gm.graph.nodes:
        if candidate.op != "placeholder":
            continue

        # The annotation must be an actual class before issubclass may be
        # applied to it; missing or non-class annotations are skipped
        annotation = candidate.type
        if isinstance(annotation, type) and issubclass(annotation, torch.Tensor):
            tensor_placeholders.append(candidate)

    return tensor_placeholders
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import logging
2+
3+
import torch
4+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
5+
clean_up_graph_after_modifications,
6+
)
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
# TODO: Delete this lowering pass once aot_export_joint_simple is patched
12+
def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Strips the auxiliary clone nodes which were added to fix input aliasing

    Any placeholder whose single user is an `aten.clone` node has that clone
    bypassed and erased. If at least one clone was removed, the graph is
    cleaned up and recompiled before the GraphModule is returned.

    See: https://github.com/pytorch/pytorch/issues/108079
    """
    removed_any_clone = False

    for candidate in gm.graph.nodes:
        if candidate.op != "placeholder":
            continue

        # A placeholder consumed solely by a clone node is treated as one
        # rewritten by the input alias-fixing pass, which must be undone
        consumers = list(candidate.users)
        if len(consumers) != 1:
            continue

        sole_consumer = consumers[0]
        is_clone = sole_consumer.target == torch.ops.aten.clone.default
        if not is_clone:
            continue

        removed_any_clone = True

        # Redirect every use of the clone to the placeholder, then delete the clone
        sole_consumer.replace_all_uses_with(candidate)
        gm.graph.erase_node(sole_consumer)

    if removed_any_clone:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")

    return gm
Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import logging
22

33
import torch
4+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
5+
clean_up_graph_after_modifications,
6+
get_tensor_placeholders,
7+
)
48

59
logger = logging.getLogger(__name__)
610

@@ -13,15 +17,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
1317
modified_graph = False
1418

1519
# Extract graph placeholder Tensors
16-
placeholders = [
17-
node
18-
for node in gm.graph.nodes
19-
if (
20-
node.op == "placeholder"
21-
and isinstance(node.type, type)
22-
and issubclass(node.type, torch.Tensor)
23-
)
24-
]
20+
placeholders = get_tensor_placeholders(gm)
2521

2622
for placeholder in placeholders:
2723
# If any placeholder has any users which are direct graph outputs
@@ -34,7 +30,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
3430
direct_outputs = [user for user in placeholder.users if user.op == "output"]
3531

3632
# Insert clone node for placeholder to ensure placeholder is not a direct output
37-
with gm.graph.inserting_after(placeholder):
33+
with gm.graph.inserting_after(placeholders[-1]):
3834
cloned_placeholder = gm.graph.call_function(
3935
torch.ops.aten.clone.default,
4036
args=(placeholder,),
@@ -45,9 +41,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
4541
output.replace_input_with(placeholder, cloned_placeholder)
4642

4743
if modified_graph:
48-
gm.graph.eliminate_dead_code()
49-
gm.graph.lint()
50-
gm.recompile()
44+
gm = clean_up_graph_after_modifications(gm)
5145
logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
5246

5347
return gm

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,12 @@ def forward(self, x):
280280

281281
def test_input_modifications_mul(self):
282282
class InplaceMul(torch.nn.Module):
283-
def forward(self, x):
283+
def forward(self, x, y):
284284
x *= 5.0
285285
x *= 1.9
286-
y = x + 1
287-
y /= 1.3
288-
return y
286+
z = x + y
287+
z /= 1.3
288+
return z
289289

290290
inputs = [
291291
torch.rand(
@@ -294,6 +294,12 @@ def forward(self, x):
294294
5,
295295
7,
296296
).cuda(),
297+
torch.rand(
298+
1,
299+
3,
300+
5,
301+
7,
302+
).cuda(),
297303
]
298304

299305
fx_graph = torch.fx.symbolic_trace(InplaceMul())

tests/py/dynamo/testing_utilities.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch_tensorrt.dynamo.lowering import (
1111
apply_lowering_passes,
1212
get_decompositions,
13-
replace_builtin_inplace_ops,
13+
repair_input_aliasing,
1414
)
1515
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1616

@@ -43,12 +43,13 @@ def fx_dynamo_testing_backend(
4343
with unittest.mock.patch.object(
4444
fake_mode, "allow_non_fake_inputs", True
4545
), fake_mode:
46-
replace_builtin_inplace_ops(gm)
46+
repair_input_aliasing(gm)
4747

4848
# Invoke AOTAutograd to translate operators to aten
4949
gm = aot_export_joint_simple(
5050
gm,
5151
sample_inputs,
52+
trace_joint=False,
5253
decompositions=get_decompositions(),
5354
)
5455

0 commit comments

Comments
 (0)