
Commit 80a8da2

fix: Remove input aliasing with builtin ops
- Add replacements for inplace builtin operators with their out-of-place equivalents
- Add utility to automatically perform replacement prior to AOT tracing
- Add test cases to verify inplace operators are replaced accurately
1 parent 90eec47 commit 80a8da2

File tree

5 files changed: +155 -59

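Note on the motivation: torch.fx.symbolic_trace records Python augmented assignment (e.g. x += 3) as a call_function node targeting the in-place builtin operator.iadd, which mutates its tensor argument when the graph runs. When that argument is a graph input, the compiled subgraph aliases and mutates caller memory, which TensorRT's functional-only execution model cannot represent. A minimal sketch of the failure mode, illustrative rather than part of this commit:

import operator

import torch


class InplaceAdd(torch.nn.Module):
    def forward(self, x):
        x += 3  # augmented assignment on a graph input
        return x + 1


gm = torch.fx.symbolic_trace(InplaceAdd())

# The traced graph carries an operator.iadd node fed by the placeholder,
# so executing the graph writes into the caller's tensor.
print(any(node.target is operator.iadd for node in gm.graph.nodes))  # True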

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 10 additions & 56 deletions
@@ -2,17 +2,19 @@
 
 import logging
 import unittest
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
-import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import _aot_export_function
-from torch._ops import OpOverload
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    apply_lowering_passes,
+    get_decompositions,
+    replace_builtin_inplace_ops,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
@@ -74,8 +76,10 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
+            replace_builtin_inplace_ops(gm)
+
             # Invoke AOTAutograd to translate operators to aten
-            gm = aot_export_for_compile(
+            gm = aot_export_joint_simple(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(
@@ -110,53 +114,3 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
-
-
-def aot_export_for_compile(
-    func: torch.fx.GraphModule,
-    args: Sequence[torch.Tensor],
-    *,
-    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
-) -> torch.fx.GraphModule:
-    """Adapted from:
-    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
-
-    Removed check for input aliasing in resultant subgraph - TRT is functional-only
-
-    Exports the function to ATen for torch compile
-    """
-    # Trace function with input arguments and decompositions
-    with torch.no_grad():
-        fx_g, metadata, in_spec, out_spec = _aot_export_function(
-            func,
-            args,
-            decompositions=decompositions,
-        )
-
-    # No input mutations
-    if (
-        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
-        != 0
-    ):
-        raise RuntimeError(
-            f"aot_export_joint_simple does not support input mutations. {str(metadata)}"
-        )
-    # No pytrees
-    if type(in_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
-        )
-    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
-        )
-    if type(out_spec) == pytree.LeafSpec:
-        raise RuntimeError(
-            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
-        )
-    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
-        raise RuntimeError(
-            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
-        )
-
-    return fx_g
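Note on the export path: the hand-rolled aot_export_for_compile removed above duplicated upstream logic minus the input-aliasing check, and the backend now calls the upstream aot_export_joint_simple directly, which itself rejects graphs that still mutate their inputs. That is why replace_builtin_inplace_ops(gm) must run before the export. A hedged sketch of the new call path; the trace_joint=False keyword is an assumption about the portion of the call truncated in the hunk above, not something this diff shows:

from torch._functorch.aot_autograd import aot_export_joint_simple


def export_to_aten(gm, sample_inputs, decompositions):
    # aot_export_joint_simple raises RuntimeError if any input is still
    # mutated, so in-place builtins must already be rewritten out of gm.
    return aot_export_joint_simple(
        gm,
        sample_inputs,
        trace_joint=False,  # assumed: inference-only (non-joint) export
        decompositions=decompositions,
    )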

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
+from ._replace_inplace_ops import replace_builtin_inplace_ops
 from .passes import add_lowering_pass, apply_lowering_passes
 from .substitutions import *  # noqa: F401
py/torch_tensorrt/dynamo/lowering/_replace_inplace_ops.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import logging
+import operator
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+BUILTIN_TRANSLATION = {
+    operator.ipow: operator.pow,
+    operator.imul: operator.mul,
+    operator.imatmul: operator.matmul,
+    operator.ifloordiv: operator.floordiv,
+    operator.itruediv: operator.truediv,
+    operator.imod: operator.mod,
+    operator.iadd: operator.add,
+    operator.isub: operator.sub,
+    operator.ilshift: operator.lshift,
+    operator.irshift: operator.rshift,
+    operator.iand: operator.and_,
+    operator.ixor: operator.xor,
+    operator.ior: operator.or_,
+}
+
+
+def replace_builtin_inplace_ops(gm: torch.fx.GraphModule) -> None:
+    """Replaces inplace builtins from Python's operator class
+
+    Replaces inplace builtins with out-of-place equivalent ops
+    """
+    for node in gm.graph.nodes:
+        # If a node uses one of the inplace builtins,
+        # replace it with its out-of-place equivalent
+        if node.target in BUILTIN_TRANSLATION:
+            out_of_place_op = BUILTIN_TRANSLATION[node.target]
+
+            # Insert the out-of-place replacement, then delete the inplace node
+            with gm.graph.inserting_before(node):
+                out_of_place = gm.graph.call_function(
+                    out_of_place_op,
+                    args=node.args,
+                    kwargs=node.kwargs,
+                )
+
+            logger.debug(f"Replacing {node.target} with {out_of_place.target}")
+
+            node.replace_all_uses_with(out_of_place)
+            gm.graph.erase_node(node)
+
+    gm.graph.lint()
+    gm.recompile()
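For context, a minimal usage sketch of the new pass, illustrative only: tracing a module that uses *= yields an operator.imul node, and running the pass rewrites it to operator.mul so the input is no longer mutated.

import operator

import torch
from torch_tensorrt.dynamo.lowering import replace_builtin_inplace_ops


class InplaceMul(torch.nn.Module):
    def forward(self, x):
        x *= 5.0  # traced as a call_function node targeting operator.imul
        return x + 1


gm = torch.fx.symbolic_trace(InplaceMul())
replace_builtin_inplace_ops(gm)

# The in-place builtin is gone; its out-of-place equivalent replaces it.
targets = [node.target for node in gm.graph.nodes]
print(operator.imul in targets, operator.mul in targets)  # False True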

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 85 additions & 0 deletions
@@ -57,6 +57,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"MulInt TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
@@ -113,6 +114,7 @@ def forward(self, x):
         self.assertAlmostEqual(
             max_diff,
             0,
+            DECIMALS_OF_AGREEMENT,
             msg=f"AddFloat TRT outputs don't match with the original model.",
         )
 
@@ -236,5 +238,88 @@ def forward(self, x):
         torch._dynamo.reset()
 
 
+class TestInputModifications(TestCase):
+    def test_input_modifications_add(self):
+        class InplaceAdd(torch.nn.Module):
+            def forward(self, x):
+                x += 3
+                y = x + 1
+                return y
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceAdd())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceAdd TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_input_modifications_mul(self):
+        class InplaceMul(torch.nn.Module):
+            def forward(self, x):
+                x *= 5.0
+                x *= 1.9
+                y = x + 1
+                y /= 1.3
+                return y
+
+        inputs = [
+            torch.rand(
+                1,
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(InplaceMul())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"InplaceMul TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
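A note on the DECIMALS_OF_AGREEMENT insertions in the first two hunks: the third positional argument of unittest's assertAlmostEqual is places, which defaults to 7. Passing DECIMALS_OF_AGREEMENT (defined as 4 in testing_utilities.py) loosens the Torch/TRT comparison from roughly 1e-7 to roughly 1e-4 agreement, presumably to tolerate normal floating-point discrepancies in TRT engines. A small illustration:

import unittest


class ToleranceDemo(unittest.TestCase):
    def test_places(self):
        max_diff = 3e-5  # a typical floating-point accumulation discrepancy
        # With the default places=7 this would fail, since round(3e-5, 7) != 0:
        #     self.assertAlmostEqual(max_diff, 0)
        # With places=4 it passes, since round(3e-5, 4) == 0:
        self.assertAlmostEqual(max_diff, 0, 4)


if __name__ == "__main__":
    unittest.main()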

tests/py/dynamo/testing_utilities.py

Lines changed: 9 additions & 3 deletions
@@ -5,9 +5,13 @@
 
 import torch
 from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
+from torch_tensorrt.dynamo.lowering import (
+    apply_lowering_passes,
+    get_decompositions,
+    replace_builtin_inplace_ops,
+)
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
@@ -39,8 +43,10 @@ def fx_dynamo_testing_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
+            replace_builtin_inplace_ops(gm)
+
             # Invoke AOTAutograd to translate operators to aten
-            gm = aot_export_for_compile(
+            gm = aot_export_joint_simple(
                 gm,
                 sample_inputs,
                 decompositions=get_decompositions(),
