Skip to content

Commit 0f10a31

Browse files
author
Wei Wei
committed
[fx2trt] exception for inplace op (#82)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/82 `check_mutable_operations` does not work. We override `create_node` function to throw exception for inplace Reviewed By: jfix71 Differential Revision: D36556398 fbshipit-source-id: db35732007eae77fb40500753d7ffe26d43e659c
1 parent 86e6e47 commit 0f10a31

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

test/tracer/test_acc_tracer.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,34 @@ def forward(self, a):
23822382
torch.equal(ref, res), f"Tensors at don't match {ref=} {res=}"
23832383
)
23842384

2385+
def test_inplace_raise(self):
2386+
"""
2387+
Test that encountering inplace is raised for exception
2388+
"""
2389+
2390+
class TestModule(nn.Module):
2391+
def __init__(self):
2392+
super().__init__()
2393+
2394+
def forward(self, a):
2395+
a = a + 2
2396+
a.sub_(3)
2397+
return a
2398+
2399+
m = TestModule()
2400+
in_a = torch.randn(5)
2401+
try:
2402+
acc_tracer.trace(
2403+
m,
2404+
[in_a],
2405+
)
2406+
self.fail("Shouldn't get here because exception should be thrown.")
2407+
except RuntimeError as e:
2408+
self.assertEqual(
2409+
"Tried to trace mutable operation sub_. FX only supports functional code",
2410+
str(e),
2411+
)
2412+
23852413
def test_repeat_interleave(self):
23862414
class TestModule(nn.Module):
23872415
def __init__(self):

tracer/acc_tracer/acc_tracer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch._sources import normalize_source_lines
1919
from torch.fx import Graph, Tracer
2020
from torch.fx.experimental.normalize import NormalizeArgs
21+
from torch.fx.node import Argument, Node, Target
2122
from torch.fx.passes import shape_prop
2223

2324

@@ -242,6 +243,42 @@ def trace(
242243
rewritten = _rewrite(root, ast_rewriter_allow_list, self.leaf_module_list)
243244
return super().trace(rewritten, concrete_args), rewritten
244245

246+
# override TraceBase's method
247+
def create_node(
248+
self,
249+
kind: str,
250+
target: Target,
251+
args: Tuple[Argument, ...],
252+
kwargs: Dict[str, Argument],
253+
name: Optional[str] = None,
254+
type_expr: Optional[Any] = None,
255+
) -> Node:
256+
"""
257+
Inserts a graph node given target, args, kwargs, and name.
258+
259+
This method can be overridden to do extra checking, validation, or
260+
modification of values used in node creation. For example, one might
261+
want to disallow in-place operations from being recorded.
262+
"""
263+
264+
## Hacky way to decide inplace ops
265+
if type(target) != str:
266+
name_target = target.__name__
267+
else:
268+
name_target = target
269+
270+
allow_list = ["and_", "or_"] # python operator.and_, operator.or_
271+
if (
272+
name_target[-1] == "_"
273+
and name_target[0] != "_"
274+
and not (name_target in allow_list)
275+
):
276+
raise RuntimeError(
277+
f"Tried to trace mutable operation {name_target}. FX only supports functional code"
278+
)
279+
280+
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
281+
245282

246283
# List of modules that need rewriting to be supported for tracing.
247284
DEFAULT_REWRITE_ALLOW_LIST = {

0 commit comments

Comments
 (0)