Commit 7daa112

fix: Remove input aliasing of builtin ops (#2276)
1 parent ecdc040 commit 7daa112

10 files changed: +238 -75 lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 11 additions & 56 deletions

@@ -2,17 +2,19 @@
 
 import logging
 import unittest
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
-import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import _aot_export_function
-from torch._ops import OpOverload
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    apply_lowering_passes,
+    get_decompositions,
+    repair_input_aliasing,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
@@ -71,10 +73,13 @@ def _pretraced_backend(
     with unittest.mock.patch.object(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
+        repair_input_aliasing(gm)
+
         # Invoke AOTAutograd to translate operators to aten
-        gm = aot_export_for_compile(
+        gm = aot_export_joint_simple(
             gm,
             sample_inputs,
+            trace_joint=False,
             decompositions=get_decompositions(
                 settings.enable_experimental_decompositions
             ),
@@ -107,53 +112,3 @@ def _pretraced_backend(
             + "specify pass_through_build_failures=False."
         )
         raise
-
-
-def aot_export_for_compile(
-    func: torch.fx.GraphModule,
-    args: Sequence[torch.Tensor],
-    *,
-    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
-) -> torch.fx.GraphModule:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
-
-    Removed check for input aliasing in resultant subgraph - TRT is functional-only
-
-    Exports the function to ATen for torch compile
-    """
-    # Trace function with input arguments and decompositions
-    with torch.no_grad():
-        fx_g, metadata, in_spec, out_spec = _aot_export_function(
-            func,
-            args,
-            decompositions=decompositions,
-        )
-
-    # No input mutations
-    if (
-        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
-        != 0
-    ):
-        raise RuntimeError(
-            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
-        )
-    # No pytrees
-    if type(in_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
-        )
-    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
-        )
-    if type(out_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
-        )
-    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
-        )
-
-    return fx_g
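The deleted helper above was a vendored copy of PyTorch's exporter with the input-aliasing check removed; the backend now calls the stock aot_export_joint_simple instead. A minimal sketch of the new call shape, runnable on a recent PyTorch 2.x build (the Add module and sample inputs are illustrative, not part of the commit):

    import torch
    from torch._functorch.aot_autograd import aot_export_joint_simple


    class Add(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y


    # Trace to an FX graph, then lower to ATen operators
    gm = torch.fx.symbolic_trace(Add())
    sample_inputs = [torch.randn(4), torch.randn(4)]

    # trace_joint=False requests the forward-only (inference) graph,
    # mirroring the call added in _pretraced_backend above
    aten_gm = aot_export_joint_simple(gm, sample_inputs, trace_joint=False)
    print(aten_gm.graph)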

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,5 +2,6 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
 from .substitutions import *  # noqa: F401
py/torch_tensorrt/dynamo/lowering/_repair_input_aliasing.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_tensor_placeholders
+
+logger = logging.getLogger(__name__)
+
+
+def repair_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Inserts clone operators temporarily ahead of every placeholder
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    Undone by `remove_input_alias_fixing_clones` after tracing
+    """
+    # Extract graph placeholder Tensors
+    placeholders = get_tensor_placeholders(gm)
+
+    for node in placeholders:
+        # Insert clones for placeholder nodes to avoid
+        # input aliasing or mutation
+        with gm.graph.inserting_after(placeholders[-1]):
+            cloned_input = gm.graph.call_function(
+                torch.ops.aten.clone.default,
+                args=(node,),
+            )
+
+        # Replace all uses of the placeholder except the cloned node
+        # with the cloned placeholder
+        node.replace_all_uses_with(
+            cloned_input,
+            delete_user_cb=lambda node: node != cloned_input,
+        )
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Inserted auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
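A minimal usage sketch, assuming a torch_tensorrt build that includes this commit (the Identity module is illustrative). Returning an input unchanged aliases a placeholder directly to the output, which is exactly the pattern the stock exporter rejects:

    import torch
    from torch_tensorrt.dynamo.lowering import repair_input_aliasing


    class Identity(torch.nn.Module):
        # The annotation matters: get_tensor_placeholders only collects
        # placeholders whose node.type is a torch.Tensor subclass
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x


    gm = torch.fx.symbolic_trace(Identity())
    gm = repair_input_aliasing(gm)
    print(gm.graph)  # the output now flows through an aten.clone.default node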

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions

@@ -5,10 +5,12 @@
 
 from .constant_folding import constant_fold
 from .pass_manager import DynamoPassManager
+from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
+        remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
     ]
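The ordering is deliberate: remove_input_alias_fixing_clones runs first, so the auxiliary clones are gone before constant folding or output repair inspect the graph. A sketch of the semantics only, since each lowering pass is a GraphModule-to-GraphModule callable (apply_pass_list is a hypothetical illustration, not the DynamoPassManager API):

    from typing import Callable, List

    import torch

    GraphPass = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]


    def apply_pass_list(
        gm: torch.fx.GraphModule, passes: List[GraphPass]
    ) -> torch.fx.GraphModule:
        # Passes run in list order: clone removal, then constant folding,
        # then input-as-output repair
        for lowering_pass in passes:
            gm = lowering_pass(gm)
        return gm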

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 4 additions & 3 deletions

@@ -2,6 +2,9 @@
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
 
 from packaging import version
 
@@ -47,9 +50,7 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     for node in erased_params:
         gm.graph.erase_node(node)
 
-    gm.graph.eliminate_dead_code()
-    gm.graph.lint()
-    gm.recompile()
+    gm = clean_up_graph_after_modifications(gm)
 
     logger.debug(f"Graph after constant folding:\n{gm.graph}")
 

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+from typing import List
+
+import torch
+
+
+def clean_up_graph_after_modifications(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Runs dead-code elimination, linting, and recompilation for graph, in-place"""
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
+def get_tensor_placeholders(
+    gm: torch.fx.GraphModule,
+) -> List[torch.fx.Node]:
+    """Returns placeholder nodes of GraphModule which are torch.Tensor types"""
+    # Tensor placeholders must be subclasses of torch.Tensor
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.Tensor)
+        )
+    ]
+
+    return placeholders
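Both helpers are plain torch.fx utilities, so they can be exercised without TensorRT. A self-contained check (the TwoInputs module is illustrative):

    import torch
    from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
        clean_up_graph_after_modifications,
        get_tensor_placeholders,
    )


    class TwoInputs(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y


    gm = torch.fx.symbolic_trace(TwoInputs())

    # Annotated Tensor inputs surface as typed placeholder nodes
    print([node.name for node in get_tensor_placeholders(gm)])  # ['x', 'y']

    # After any manual graph surgery, one call restores the graph invariants
    gm = clean_up_graph_after_modifications(gm)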
py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Delete this lowering pass once aot_export_joint_simple is patched
+def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove the auxiliary clone nodes inserted to fix input aliasing
+
+    See: https://github.com/pytorch/pytorch/issues/108079
+    """
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node is a placeholder and its only user is a clone node
+        # it was modified by the input alias-fixing pass, and the change
+        # needs to be undone
+        if (
+            node.op == "placeholder"
+            and len(node.users) == 1
+            and list(node.users)[0].target == torch.ops.aten.clone.default
+        ):
+            modified_graph = True
+
+            # Replace all uses of the clone with the placeholder, delete the clone
+            clone_node = list(node.users)[0]
+            logger.debug(
+                f"Removing node {clone_node} from graph, since it is a clone node which "
+                f"is the only user of placeholder {node} and was inserted by the compiler."
+            )
+            clone_node.replace_all_uses_with(node)
+            gm.graph.erase_node(clone_node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Removed auxiliary clone nodes for placeholders:\n{gm.graph}")
+
+    return gm
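Together with repair_input_aliasing, this pass is meant to round-trip: clones are inserted before export and stripped out afterwards. A sketch under the same assumptions as above (the Scale module is illustrative):

    import torch
    from torch_tensorrt.dynamo.lowering import repair_input_aliasing
    from torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones import (
        remove_input_alias_fixing_clones,
    )


    class Scale(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2


    gm = torch.fx.symbolic_trace(Scale())
    gm = repair_input_aliasing(gm)             # placeholder -> clone -> mul
    gm = remove_input_alias_fixing_clones(gm)  # back to placeholder -> mul
    print(gm.graph)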
py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py

Lines changed: 7 additions & 13 deletions

@@ -1,6 +1,10 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+    get_tensor_placeholders,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -13,15 +17,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     modified_graph = False
 
     # Extract graph placeholder Tensors
-    placeholders = [
-        node
-        for node in gm.graph.nodes
-        if (
-            node.op == "placeholder"
-            and isinstance(node.type, type)
-            and issubclass(node.type, torch.Tensor)
-        )
-    ]
+    placeholders = get_tensor_placeholders(gm)
 
     for placeholder in placeholders:
         # If any placeholder has any users which are direct graph outputs
@@ -34,7 +30,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
         direct_outputs = [user for user in placeholder.users if user.op == "output"]
 
         # Insert clone node for placeholder to ensure placeholder is not a direct output
-        with gm.graph.inserting_after(placeholder):
+        with gm.graph.inserting_after(placeholders[-1]):
             cloned_placeholder = gm.graph.call_function(
                 torch.ops.aten.clone.default,
                 args=(placeholder,),
@@ -45,9 +41,7 @@ def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
             output.replace_input_with(placeholder, cloned_placeholder)
 
     if modified_graph:
-        gm.graph.eliminate_dead_code()
-        gm.graph.lint()
-        gm.recompile()
+        gm = clean_up_graph_after_modifications(gm)
         logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
 
     return gm
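The case this pass repairs is a placeholder that is also a direct graph output, which a functional-only backend like TRT cannot represent. A sketch under the same assumptions (the PassThrough module is illustrative):

    import torch
    from torch_tensorrt.dynamo.lowering.passes.repair_input_as_output import (
        repair_input_as_output,
    )


    class PassThrough(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            # x is returned as-is alongside a computed value, so the
            # placeholder is a direct graph output before the repair
            return x, x + 1


    gm = torch.fx.symbolic_trace(PassThrough())
    gm = repair_input_as_output(gm)
    print(gm.graph)  # the returned input is now an aten.clone.default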
