|
16 | 16 | pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
17 | 17 | sys.path.append(pytorch_test_dir)
|
18 | 18 |
|
| 19 | +def secretly_aliasing(x): |
| 20 | + return x.view(-1) |
| 21 | + |
| 22 | +def secretly_mutating(x): |
| 23 | + x.mul_(2) |
| 24 | + return x * 3 |
| 25 | + |
| 26 | +def output_is_input(x): |
| 27 | + return x |
| 28 | + |
| 29 | +custom_lib = torch.library.Library("bad_schemas", "DEF") |
| 30 | +custom_lib.define("secretly_aliasing(Tensor x) -> Tensor") |
| 31 | +custom_lib.define("secretly_mutating(Tensor x) -> Tensor") |
| 32 | +custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)") |
| 33 | + |
| 34 | +custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU") |
| 35 | +custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing) |
| 36 | +custom_lib_cpu.impl("secretly_mutating", secretly_mutating) |
| 37 | +custom_lib_cpu.impl("output_is_input", output_is_input) |
| 38 | + |
| 39 | +custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta") |
| 40 | +custom_lib_meta.impl("secretly_aliasing", secretly_aliasing) |
| 41 | +custom_lib_meta.impl("secretly_mutating", secretly_mutating) |
| 42 | +custom_lib_meta.impl("output_is_input", output_is_input) |
| 43 | + |
19 | 44 | # This TorchDispatchTensor Subclass is used to simulate an incorrect schema
|
20 | 45 | # which is then used to test that SchemaCheckMode behaves as expected
|
21 | 46 |
|
@@ -365,6 +390,36 @@ def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
|
365 | 390 | with SchemaCheckMode() as s:
|
366 | 391 | IncorrectAliasTensor(x).aminmax(dim=0)
|
367 | 392 |
|
| 393 | + # When this file was written, python op registration didn't exist. |
| 394 | + # It's probably worth re-writing the entire file to use it, |
| 395 | + # but instead I just added extra tests. |
| 396 | + def test_alias_check_fail_custom_ops_secretly_aliasing(self): |
| 397 | + def f(x): |
| 398 | + return torch.ops.bad_schemas.secretly_aliasing(x) |
| 399 | + |
| 400 | + x = torch.rand((3, 3)) |
| 401 | + with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"): |
| 402 | + with SchemaCheckMode() as s: |
| 403 | + out = f(x) |
| 404 | + |
| 405 | + def test_alias_check_fail_custom_ops_secretly_mutating(self): |
| 406 | + def f(x): |
| 407 | + return torch.ops.bad_schemas.secretly_mutating(x) |
| 408 | + |
| 409 | + x = torch.rand((3, 3)) |
| 410 | + with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"): |
| 411 | + with SchemaCheckMode() as s: |
| 412 | + out = f(x) |
| 413 | + |
| 414 | + def test_alias_check_fail_custom_ops_output_is_input(self): |
| 415 | + def f(x): |
| 416 | + return torch.ops.bad_schemas.output_is_input(x) |
| 417 | + |
| 418 | + x = torch.rand((3, 3)) |
| 419 | + with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"): |
| 420 | + with SchemaCheckMode() as s: |
| 421 | + out = f(x) |
| 422 | + |
368 | 423 | # Tests that is_alias_of returns as expected
|
369 | 424 | def test_is_alias_of_basic(self):
|
370 | 425 | x = torch.rand((3, 3), requires_grad=True)
|
|
0 commit comments