Commit a4056cc

fix: Remove input aliasing with builtin ops
- Add replacements for inplace builtin operators with their out-of-place equivalents
- Add utility to automatically perform replacement prior to AOT tracing
- Add test cases to verify inplace operators are replaced accurately
1 parent: c1313ea
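
For context, a minimal sketch (not part of this commit) of how an augmented assignment like x += 3 surfaces in an FX graph; the InplaceAdd module here is illustrative only:

    import operator

    import torch


    class InplaceAdd(torch.nn.Module):
        def forward(self, x):
            x += 3  # traced as a call to operator.iadd, which aliases the graph input
            return x + 1


    gm = torch.fx.symbolic_trace(InplaceAdd())

    # The traced graph carries an operator.iadd node; AOT export utilities
    # reject graphs whose inputs are mutated, hence this commit's pass.
    assert any(node.target is operator.iadd for node in gm.graph.nodes)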

File tree

5 files changed: +156 −60 lines

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 10 additions & 56 deletions
@@ -2,17 +2,19 @@
 
 import logging
 import unittest
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
-import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import _aot_export_function
-from torch._ops import OpOverload
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering import ATEN_LOWERING_PASSES, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_LOWERING_PASSES,
+    get_decompositions,
+    replace_builtin_inplace_ops,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
@@ -73,8 +75,10 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
+            replace_builtin_inplace_ops(gm)
+
             # Invoke AOTAutograd to translate operators to aten
-            gm = aot_export_for_compile(
+            gm = aot_export_joint_simple(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -109,53 +113,3 @@ def _pretraced_backend(
                     + "specify pass_through_build_failures=False."
                 )
             raise
-
-
-def aot_export_for_compile(
-    func: torch.fx.GraphModule,
-    args: Sequence[torch.Tensor],
-    *,
-    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
-) -> torch.fx.GraphModule:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
-
-    Removed check for input aliasing in resultant subgraph - TRT is functional-only
-
-    Exports the function to ATen for torch compile
-    """
-    # Trace function with input arguments and decompositions
-    with torch.no_grad():
-        fx_g, metadata, in_spec, out_spec = _aot_export_function(
-            func,
-            args,
-            decompositions=decompositions,
-        )
-
-    # No input mutations
-    if (
-        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
-        != 0
-    ):
-        raise RuntimeError(
-            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
-        )
-    # No pytrees
-    if type(in_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
-        )
-    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
-        )
-    if type(out_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
-        )
-    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
-        )
-
-    return fx_g
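
The deleted aot_export_for_compile helper above duplicated upstream validation; aot_export_joint_simple performs equivalent input-mutation and pytree checks in PyTorch itself. A minimal sketch of the new flow, assuming an inference-only export (the trace_joint=False keyword and the export_to_aten wrapper are assumptions for illustration, since the hunk is truncated before the call's remaining arguments):

    import torch
    from torch._functorch.aot_autograd import aot_export_joint_simple
    from torch_tensorrt.dynamo.lowering import (
        get_decompositions,
        replace_builtin_inplace_ops,
    )


    def export_to_aten(gm: torch.fx.GraphModule, sample_inputs):
        # Hypothetical wrapper mirroring the backend's flow: rewrite builtin
        # inplace ops first, since aot_export_joint_simple raises on graphs
        # that mutate their inputs.
        replace_builtin_inplace_ops(gm)
        return aot_export_joint_simple(
            gm,
            sample_inputs,
            trace_joint=False,  # assumption: inference-only export
            decompositions=get_decompositions(),
        )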

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from ._replace_inplace_ops import replace_builtin_inplace_ops
 from .passes import ATEN_LOWERING_PASSES
 from .substitutions import *  # noqa: F401

py/torch_tensorrt/dynamo/lowering/_replace_inplace_ops.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import logging
+import operator
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+BUILTIN_TRANSLATION = {
+    operator.ipow: operator.pow,
+    operator.imul: operator.mul,
+    operator.imatmul: operator.matmul,
+    operator.ifloordiv: operator.floordiv,
+    operator.itruediv: operator.truediv,
+    operator.imod: operator.mod,
+    operator.iadd: operator.add,
+    operator.isub: operator.sub,
+    operator.ilshift: operator.lshift,
+    operator.irshift: operator.rshift,
+    operator.iand: operator.and_,
+    operator.ixor: operator.xor,
+    operator.ior: operator.or_,
+}
+
+
+def replace_builtin_inplace_ops(gm: torch.fx.GraphModule) -> None:
+    """Replaces inplace builtins from Python's operator module
+
+    Replaces inplace builtins with their out-of-place equivalents
+    """
+    for node in gm.graph.nodes:
+        # If a node uses one of the inplace builtins,
+        # replace it with its out-of-place equivalent
+        if node.target in BUILTIN_TRANSLATION:
+            out_of_place_op = BUILTIN_TRANSLATION[node.target]
+
+            # Insert the out-of-place equivalent, then remove the inplace node
+            with gm.graph.inserting_before(node):
+                out_of_place = gm.graph.call_function(
+                    out_of_place_op,
+                    args=node.args,
+                    kwargs=node.kwargs,
+                )
+
+            logger.debug(f"Replacing {node.target} with {out_of_place.target}")
+
+            node.replace_all_uses_with(out_of_place)
+            gm.graph.erase_node(node)
+
+    gm.graph.lint()
+    gm.recompile()
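
A minimal usage sketch of the new pass (not part of the diff), with an illustrative Scale module:

    import operator

    import torch
    from torch_tensorrt.dynamo.lowering import replace_builtin_inplace_ops


    class Scale(torch.nn.Module):
        def forward(self, x):
            x *= 2.0  # traced as operator.imul
            return x + 1


    gm = torch.fx.symbolic_trace(Scale())
    replace_builtin_inplace_ops(gm)

    # After the pass, the imul node is gone and an out-of-place mul remains.
    assert not any(node.target is operator.imul for node in gm.graph.nodes)
    assert any(node.target is operator.mul for node in gm.graph.nodes)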

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 86 additions & 1 deletion
@@ -2,7 +2,7 @@
 import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-from ..testing_utilities import lower_graph_testing
+from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
 
 class TestFakeTensors(TestCase):
@@ -57,6 +57,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"MulInt TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
@@ -113,6 +114,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"AddFloat TRT outputs don't match with the original model.",
         )
 
@@ -157,5 +159,88 @@ def forward(self, x):
         torch._dynamo.reset()
 
 
+class TestInputModifications(TestCase):
+    def test_input_modifications_add(self):
+        class InplaceAdd(torch.nn.Module):
+            def forward(self, x):
+                x += 3
+                y = x + 1
+                return y
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceAdd())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceAdd TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_input_modifications_mul(self):
+        class InplaceMul(torch.nn.Module):
+            def forward(self, x):
+                x *= 5.0
+                x *= 1.9
+                y = x + 1
+                y /= 1.3
+                return y
+
+        inputs = [
+            torch.rand(
+                1,
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceMul())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceMul TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/testing_utilities.py

Lines changed: 9 additions & 3 deletions
@@ -5,9 +5,13 @@
 
 import torch
 from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile
-from torch_tensorrt.dynamo.lowering import ATEN_LOWERING_PASSES, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_LOWERING_PASSES,
+    get_decompositions,
+    replace_builtin_inplace_ops,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
@@ -39,8 +43,10 @@ def fx_dynamo_testing_backend(
     with unittest.mock.patch.object(
         fake_mode, "allow_non_fake_inputs", True
    ), fake_mode:
+        replace_builtin_inplace_ops(gm)
+
         # Invoke AOTAutograd to translate operators to aten
-        gm = aot_export_for_compile(
+        gm = aot_export_joint_simple(
             gm,
             sample_inputs,
             decompositions=get_decompositions(),
