Skip to content

Commit 1f2d00e

Browse files
bdhirshpytorchmergebot
authored andcommitted
move SchemaCheckMode to torch/_subclasses (pytorch#99743)
Pull Request resolved: pytorch#99743 Approved by: https://github.com/albanD
1 parent 884c5c8 commit 1f2d00e

File tree

3 files changed

+64
-32
lines changed

3 files changed

+64
-32
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ exclude_patterns = [
116116
'torch/_functorch/partitioners.py',
117117
'torch/_functorch/top_operators_github_usage.py',
118118
'torch/_functorch/vmap.py',
119+
'torch/_subclasses/schema_check_mode.py',
119120
'torch/distributed/elastic/agent/server/api.py',
120121
'torch/testing/_internal/**',
121122
'torch/distributed/fsdp/fully_sharded_data_parallel.py',

test/test_schema_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from torch.testing._internal.common_utils import run_tests
1010
from torch.fx.operator_schemas import normalize_function
11-
from torch.testing._internal.schema_check_mode import SchemaCheckMode
11+
from torch._subclasses.schema_check_mode import SchemaCheckMode
1212
from torch.utils._python_dispatch import TorchDispatchMode
1313
from torch.testing._internal.common_methods_invocations import op_db
1414
from torch.testing._internal.jit_utils import JitTestCase

torch/testing/_internal/schema_check_mode.py renamed to torch/_subclasses/schema_check_mode.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
from collections import namedtuple
2+
from copy import deepcopy
3+
from itertools import combinations
4+
15
import torch
2-
from torch.utils._pytree import tree_flatten, tree_map
36
from torch.fx.operator_schemas import normalize_function
47
from torch.testing._internal.jit_utils import clone_inputs
58
from torch.utils._python_dispatch import TorchDispatchMode
6-
from itertools import combinations
7-
from collections import namedtuple
8-
from copy import deepcopy
9+
from torch.utils._pytree import tree_flatten, tree_map
910

1011
# Named Tuples used within SchemaCheckMode
11-
Mutation = namedtuple('Mutation', ['op_name', 'arg_name'])
12-
Aliasing = namedtuple('Aliasing', ['op_name', 'arg_name', 'output_number'])
12+
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
13+
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])
1314

1415
# Simplified naming for C++ classes
1516
SchemaArgument = torch._C._SchemaArgument
@@ -22,6 +23,7 @@
2223
# - Checks for mutations on all inputs
2324
# - Checks for aliasing on all inputs
2425

26+
2527
class SchemaCheckMode(TorchDispatchMode):
2628
def __init__(self):
2729
# Information recorded for testing purposes. For example:
@@ -42,12 +44,16 @@ def display_ops(self):
4244
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
4345
def has_mutated(before, after, md):
4446
are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
45-
if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
47+
if (
48+
are_tensors
49+
and before.layout != torch.sparse_csr
50+
and after.layout != torch.sparse_csr
51+
):
4652
return not (
47-
before.size() == after.size() and
48-
torch.allclose(before, after, equal_nan=True) and
49-
md[0] == after.stride() and
50-
md[1] == after._typed_storage()._cdata
53+
before.size() == after.size()
54+
and torch.allclose(before, after, equal_nan=True)
55+
and md[0] == after.stride()
56+
and md[1] == after._typed_storage()._cdata
5157
)
5258
return False
5359

@@ -76,31 +82,38 @@ def parse_metadata(e):
7682
if not type(e) == torch.Tensor:
7783
try:
7884
current = e.elem
79-
return (deepcopy(current.stride()), current._typed_storage()._cdata)
85+
return (
86+
deepcopy(current.stride()),
87+
current._typed_storage()._cdata,
88+
)
8089
except AttributeError as t:
8190
return None
8291
# Sparse CSR tensors do not have strides or storage
83-
elif (e.layout != torch.sparse_csr):
92+
elif e.layout != torch.sparse_csr:
8493
return (deepcopy(e.stride()), e._typed_storage()._cdata)
8594
return None
8695

8796
self.ops.append(func._schema.name)
8897

8998
# Clone and process arguments and outputs
9099
pre_arguments = normalize_function(
91-
func,
92-
args,
93-
kwargs,
94-
normalize_to_only_use_kwargs=True
100+
func, args, kwargs, normalize_to_only_use_kwargs=True
95101
).kwargs
96102

97103
c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
98-
cloned_arguments = {name : tree_map(unwrap, c_p_args.get(name)) for name in c_p_args}
99-
cloned_metadata = {name : tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}
104+
cloned_arguments = {
105+
name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
106+
}
107+
cloned_metadata = {
108+
name: tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0])
109+
for name in pre_arguments
110+
}
100111

101112
out = func(*args, **kwargs)
102-
arguments = {name : tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments}
103-
tuple_out = out if isinstance(out, tuple) else (out, )
113+
arguments = {
114+
name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
115+
}
116+
tuple_out = out if isinstance(out, tuple) else (out,)
104117
tuple_out = tree_map(unwrap, tuple_out)
105118

106119
schema_info = SchemaInfo(func._schema)
@@ -116,17 +129,34 @@ def parse_metadata(e):
116129
after = arguments.get(name)
117130
for j in range(len(tuple_out)):
118131
# aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
119-
unsafe_ops = ('aten::_unsafe_view', 'aten::unsafe_split')
120-
if has_aliased(tuple_out[j], after) and func._schema.name not in unsafe_ops:
132+
unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
133+
if (
134+
has_aliased(tuple_out[j], after)
135+
and func._schema.name not in unsafe_ops
136+
):
121137
if not schema_info.may_contain_alias(
122138
SchemaArgument(SchemaArgType.output, j),
123-
SchemaArgument(SchemaArgType.input, i)):
124-
raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
139+
SchemaArgument(SchemaArgType.input, i),
140+
):
141+
raise RuntimeError(
142+
f"Argument {name} is not defined to alias output but was aliasing"
143+
)
125144
else:
126-
self.aliasing.append(Aliasing(func._schema.name, name, f"output_{j}"))
127-
if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)):
128-
if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
129-
raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
145+
self.aliasing.append(
146+
Aliasing(func._schema.name, name, f"output_{j}")
147+
)
148+
if any(
149+
has_mutated(a, b, c)
150+
for a, b, c in zip(
151+
tree_flatten(before)[0], tree_flatten(after)[0], md
152+
)
153+
):
154+
if not schema_info.is_mutable(
155+
SchemaArgument(SchemaArgType.input, i)
156+
):
157+
raise RuntimeError(
158+
f"Argument {name} is not defined as mutable but was mutated"
159+
)
130160
else:
131161
self.mutated.append(Mutation(func._schema.name, name))
132162

@@ -135,7 +165,8 @@ def parse_metadata(e):
135165
if has_aliased(tuple_out[i], tuple_out[j]):
136166
if not schema_info.may_contain_alias(
137167
SchemaArgument(SchemaArgType.output, i),
138-
SchemaArgument(SchemaArgType.output, j)):
139-
raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')
168+
SchemaArgument(SchemaArgType.output, j),
169+
):
170+
raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
140171

141172
return out

0 commit comments

Comments
 (0)