
Commit 9834358

bdhirsh authored and pytorchmergebot committed
Get SchemaCheckMode to error on ops that return inputs directly. Expose as a dynamo backend, eager_debug (pytorch#99744)
Talked to @zou3519 and @ezyang on what the right UX is: tentatively, adding a new dynamo backend is cheap and simple, so it seems worth doing. And longer term, we agreed (?) that it's worth seeing if we can get custom ops sanity asserts to run more automatically, instead of needing a separate backend.

Side comment: that actually seems tough. The mode detects secret mutations by cloning every input to every op, running the op, and checking that the data matches between the real input and the cloned input. So I doubt we'll be able to make that behavior always-on? It would need some config at least.

Pull Request resolved: pytorch#99744
Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
1 parent 1f2d00e commit 9834358
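
The clone-and-compare strategy described in the commit message can be illustrated with a small standalone sketch. This is not the actual SchemaCheckMode implementation; check_mutates is a made-up helper for illustration only:

import torch

# Illustration only, not SchemaCheckMode's real code: clone every input,
# run the op, then compare each real input against its pre-call clone.
def check_mutates(op, *inputs):
    clones = [x.clone() for x in inputs]
    op(*inputs)
    return [not torch.equal(x, c) for x, c in zip(inputs, clones)]

print(check_mutates(lambda x: x.mul_(2), torch.ones(3)))  # [True]: the op secretly mutated its input

This is also why the commit message doubts the checks can be made always-on: every op call pays for a full clone of every tensor input.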

File tree

3 files changed, +82 -0 lines changed

test/test_schema_check.py
torch/_dynamo/backends/debugging.py
torch/_subclasses/schema_check_mode.py


test/test_schema_check.py

Lines changed: 55 additions & 0 deletions
@@ -16,6 +16,31 @@
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
 
+def secretly_aliasing(x):
+    return x.view(-1)
+
+def secretly_mutating(x):
+    x.mul_(2)
+    return x * 3
+
+def output_is_input(x):
+    return x
+
+custom_lib = torch.library.Library("bad_schemas", "DEF")
+custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
+custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
+custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")
+
+custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
+custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
+custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
+custom_lib_cpu.impl("output_is_input", output_is_input)
+
+custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
+custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
+custom_lib_meta.impl("secretly_mutating", secretly_mutating)
+custom_lib_meta.impl("output_is_input", output_is_input)
+
 # This TorchDispatchTensor Subclass is used to simulate an incorrect schema
 # which is then used to test that SchemaCheckMode behaves as expected

@@ -365,6 +390,36 @@ def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
         with SchemaCheckMode() as s:
             IncorrectAliasTensor(x).aminmax(dim=0)
 
+    # When this file was written, python op registration didn't exist.
+    # It's probably worth re-writing the entire file to use it,
+    # but instead I just added extra tests.
+    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
+        def f(x):
+            return torch.ops.bad_schemas.secretly_aliasing(x)
+
+        x = torch.rand((3, 3))
+        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
+            with SchemaCheckMode() as s:
+                out = f(x)
+
+    def test_alias_check_fail_custom_ops_secretly_mutating(self):
+        def f(x):
+            return torch.ops.bad_schemas.secretly_mutating(x)
+
+        x = torch.rand((3, 3))
+        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
+            with SchemaCheckMode() as s:
+                out = f(x)
+
+    def test_alias_check_fail_custom_ops_output_is_input(self):
+        def f(x):
+            return torch.ops.bad_schemas.output_is_input(x)
+
+        x = torch.rand((3, 3))
+        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
+            with SchemaCheckMode() as s:
+                out = f(x)
+
     # Tests that is_alias_of returns as expected
     def test_is_alias_of_basic(self):
         x = torch.rand((3, 3), requires_grad=True)
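
For reference, a hedged sketch of what one of these checks looks like outside the test harness; it assumes the bad_schemas ops defined above have already been registered:

import torch
from torch._subclasses.schema_check_mode import SchemaCheckMode

# Run a misbehaving custom op under SchemaCheckMode and surface its error.
x = torch.rand(3, 3)
try:
    with SchemaCheckMode():
        torch.ops.bad_schemas.secretly_mutating(x)
except RuntimeError as e:
    print(e)  # "... not defined as mutable but was mutated ..."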

torch/_dynamo/backends/debugging.py

Lines changed: 14 additions & 0 deletions
@@ -18,6 +18,20 @@ def eager(gm, fake_tensor_inputs):
     return gm
 
 
+@register_backend
+def eager_debug(gm, fake_tensor_inputs):
+    from torch._subclasses.schema_check_mode import SchemaCheckMode
+
+    # We could add more debugging bits here.
+    # Right now, this backend can be used to check for and error on
+    # custom dispatcher ops that have incorrect schemas.
+    def inner(*args):
+        with SchemaCheckMode():
+            return torch.fx.Interpreter(gm).run(*args)
+
+    return inner
+
+
 @register_backend(name="ts")
 def torchscript(gm, fake_tensor_inputs):
     return torch.jit.script(gm)
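
A hedged usage sketch of the new backend, assuming the bad_schemas.secretly_aliasing op from the test file above is registered; any custom op with an incorrect schema would behave the same way:

import torch

def f(x):
    return torch.ops.bad_schemas.secretly_aliasing(x)

# The FX graph produced by dynamo runs eagerly under SchemaCheckMode,
# so the schema violation surfaces as a RuntimeError on the first call.
compiled = torch.compile(f, backend="eager_debug")
compiled(torch.rand(3, 3))  # raises: "... not defined to alias output but was aliasing ..."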

torch/_subclasses/schema_check_mode.py

Lines changed: 13 additions & 0 deletions
@@ -145,6 +145,19 @@ def parse_metadata(e):
                         self.aliasing.append(
                             Aliasing(func._schema.name, name, f"output_{j}")
                         )
+                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
+                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
+                        if not schema_info.is_mutable(
+                            SchemaArgument(SchemaArgType.input, i)
+                        ) and func not in [
+                            torch.ops.aten.lift.default,
+                            torch.ops.aten.lift_fresh.default,
+                        ]:
+                            raise RuntimeError(
+                                f"""\
+Dispatcher operators below autograd are not allowed to directly return inputs.
+However, we found that `outputs[{str(j)}] is {name}"""
+                            )
             if any(
                 has_mutated(a, b, c)
                 for a, b, c in zip(
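
With this check in place, a functional custom op must return a fresh tensor rather than its input. A minimal sketch; the fixed_schemas namespace and identity op are made up for illustration:

import torch

fixed_lib = torch.library.Library("fixed_schemas", "DEF")
fixed_lib.define("identity(Tensor x) -> Tensor")

def identity_impl(x):
    # Returning `x` itself would trip the new error when run under SchemaCheckMode;
    # returning a clone (or any freshly allocated tensor) is fine.
    return x.clone()

fixed_lib.impl("identity", identity_impl, "CPU")

Only mutable ops (e.g. add_, add.out) and the aten.lift / aten.lift_fresh special cases are still allowed to return their inputs directly.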
