1
+ from collections import namedtuple
2
+ from copy import deepcopy
3
+ from itertools import combinations
4
+
1
5
import torch
2
- from torch .utils ._pytree import tree_flatten , tree_map
3
6
from torch .fx .operator_schemas import normalize_function
4
7
from torch .testing ._internal .jit_utils import clone_inputs
5
8
from torch .utils ._python_dispatch import TorchDispatchMode
6
- from itertools import combinations
7
- from collections import namedtuple
8
- from copy import deepcopy
9
+ from torch .utils ._pytree import tree_flatten , tree_map
9
10
10
11
# Named Tuples used within SchemaCheckMode
11
- Mutation = namedtuple (' Mutation' , [' op_name' , ' arg_name' ])
12
- Aliasing = namedtuple (' Aliasing' , [' op_name' , ' arg_name' , ' output_number' ])
12
+ Mutation = namedtuple (" Mutation" , [" op_name" , " arg_name" ])
13
+ Aliasing = namedtuple (" Aliasing" , [" op_name" , " arg_name" , " output_number" ])
13
14
14
15
# Simplified naming for C++ classes
15
16
SchemaArgument = torch ._C ._SchemaArgument
22
23
# - Checks for mutations on all inputs
23
24
# - Checks for aliasing on all inputs
24
25
26
+
25
27
class SchemaCheckMode (TorchDispatchMode ):
26
28
def __init__ (self ):
27
29
# Information recorded for testing purposes. For example:
@@ -42,12 +44,16 @@ def display_ops(self):
42
44
def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
43
45
def has_mutated (before , after , md ):
44
46
are_tensors = type (before ) == torch .Tensor and type (after ) == torch .Tensor
45
- if are_tensors and before .layout != torch .sparse_csr and after .layout != torch .sparse_csr :
47
+ if (
48
+ are_tensors
49
+ and before .layout != torch .sparse_csr
50
+ and after .layout != torch .sparse_csr
51
+ ):
46
52
return not (
47
- before .size () == after .size () and
48
- torch .allclose (before , after , equal_nan = True ) and
49
- md [0 ] == after .stride () and
50
- md [1 ] == after ._typed_storage ()._cdata
53
+ before .size () == after .size ()
54
+ and torch .allclose (before , after , equal_nan = True )
55
+ and md [0 ] == after .stride ()
56
+ and md [1 ] == after ._typed_storage ()._cdata
51
57
)
52
58
return False
53
59
@@ -76,31 +82,38 @@ def parse_metadata(e):
76
82
if not type (e ) == torch .Tensor :
77
83
try :
78
84
current = e .elem
79
- return (deepcopy (current .stride ()), current ._typed_storage ()._cdata )
85
+ return (
86
+ deepcopy (current .stride ()),
87
+ current ._typed_storage ()._cdata ,
88
+ )
80
89
except AttributeError as t :
81
90
return None
82
91
# Sparse CSR tensors do not have strides or storage
83
- elif ( e .layout != torch .sparse_csr ) :
92
+ elif e .layout != torch .sparse_csr :
84
93
return (deepcopy (e .stride ()), e ._typed_storage ()._cdata )
85
94
return None
86
95
87
96
self .ops .append (func ._schema .name )
88
97
89
98
# Clone and process arguments and outputs
90
99
pre_arguments = normalize_function (
91
- func ,
92
- args ,
93
- kwargs ,
94
- normalize_to_only_use_kwargs = True
100
+ func , args , kwargs , normalize_to_only_use_kwargs = True
95
101
).kwargs
96
102
97
103
c_p_args = dict (zip (pre_arguments .keys (), clone_inputs (pre_arguments .values ())))
98
- cloned_arguments = {name : tree_map (unwrap , c_p_args .get (name )) for name in c_p_args }
99
- cloned_metadata = {name : tree_map (parse_metadata , tree_flatten (pre_arguments .get (name ))[0 ]) for name in pre_arguments }
104
+ cloned_arguments = {
105
+ name : tree_map (unwrap , c_p_args .get (name )) for name in c_p_args
106
+ }
107
+ cloned_metadata = {
108
+ name : tree_map (parse_metadata , tree_flatten (pre_arguments .get (name ))[0 ])
109
+ for name in pre_arguments
110
+ }
100
111
101
112
out = func (* args , ** kwargs )
102
- arguments = {name : tree_map (unwrap , pre_arguments .get (name )) for name in pre_arguments }
103
- tuple_out = out if isinstance (out , tuple ) else (out , )
113
+ arguments = {
114
+ name : tree_map (unwrap , pre_arguments .get (name )) for name in pre_arguments
115
+ }
116
+ tuple_out = out if isinstance (out , tuple ) else (out ,)
104
117
tuple_out = tree_map (unwrap , tuple_out )
105
118
106
119
schema_info = SchemaInfo (func ._schema )
@@ -116,17 +129,34 @@ def parse_metadata(e):
116
129
after = arguments .get (name )
117
130
for j in range (len (tuple_out )):
118
131
# aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
119
- unsafe_ops = ('aten::_unsafe_view' , 'aten::unsafe_split' )
120
- if has_aliased (tuple_out [j ], after ) and func ._schema .name not in unsafe_ops :
132
+ unsafe_ops = ("aten::_unsafe_view" , "aten::unsafe_split" )
133
+ if (
134
+ has_aliased (tuple_out [j ], after )
135
+ and func ._schema .name not in unsafe_ops
136
+ ):
121
137
if not schema_info .may_contain_alias (
122
138
SchemaArgument (SchemaArgType .output , j ),
123
- SchemaArgument (SchemaArgType .input , i )):
124
- raise RuntimeError (f'Argument { name } is not defined to alias output but was aliasing' )
139
+ SchemaArgument (SchemaArgType .input , i ),
140
+ ):
141
+ raise RuntimeError (
142
+ f"Argument { name } is not defined to alias output but was aliasing"
143
+ )
125
144
else :
126
- self .aliasing .append (Aliasing (func ._schema .name , name , f"output_{ j } " ))
127
- if any (has_mutated (a , b , c ) for a , b , c in zip (tree_flatten (before )[0 ], tree_flatten (after )[0 ], md )):
128
- if not schema_info .is_mutable (SchemaArgument (SchemaArgType .input , i )):
129
- raise RuntimeError (f"Argument { name } is not defined as mutable but was mutated" )
145
+ self .aliasing .append (
146
+ Aliasing (func ._schema .name , name , f"output_{ j } " )
147
+ )
148
+ if any (
149
+ has_mutated (a , b , c )
150
+ for a , b , c in zip (
151
+ tree_flatten (before )[0 ], tree_flatten (after )[0 ], md
152
+ )
153
+ ):
154
+ if not schema_info .is_mutable (
155
+ SchemaArgument (SchemaArgType .input , i )
156
+ ):
157
+ raise RuntimeError (
158
+ f"Argument { name } is not defined as mutable but was mutated"
159
+ )
130
160
else :
131
161
self .mutated .append (Mutation (func ._schema .name , name ))
132
162
@@ -135,7 +165,8 @@ def parse_metadata(e):
135
165
if has_aliased (tuple_out [i ], tuple_out [j ]):
136
166
if not schema_info .may_contain_alias (
137
167
SchemaArgument (SchemaArgType .output , i ),
138
- SchemaArgument (SchemaArgType .output , j )):
139
- raise RuntimeError (f'Outputs { i } and { j } alias unexpectedly' )
168
+ SchemaArgument (SchemaArgType .output , j ),
169
+ ):
170
+ raise RuntimeError (f"Outputs { i } and { j } alias unexpectedly" )
140
171
141
172
return out
0 commit comments