20
20
is_channel_last_dim_order ,
21
21
is_contiguous_dim_order ,
22
22
)
23
+ from executorch .exir .pass_base import ExportPass
23
24
24
25
from torch .export import export
26
+
27
+ from torch .fx .passes .infra .pass_manager import PassManager
25
28
from torch .testing import FileCheck
26
29
from torch .utils ._pytree import tree_flatten
27
30
@@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99
102
return t1 * t2
100
103
101
104
105
+ class AmbiguousDimOrderError (RuntimeError ):
106
+ pass
107
+
108
+
109
+ def assert_unambiguous_dim_order (gm ):
110
+ class ExampleNOPPass (ExportPass ):
111
+ """
112
+ Does nothing!
113
+ """
114
+
115
+ def call_operator (self , op , args , kwargs , meta ):
116
+ return super ().call_operator (
117
+ op ,
118
+ args ,
119
+ kwargs ,
120
+ meta ,
121
+ )
122
+
123
+ # This is an example of how one can detect ambiguous dim_order anywhere in the graph.
124
+ # You can be surgical and only detect it in the nodes you are interested in or something else.
125
+ def detect_ambiguity (gm ):
126
+ """
127
+ Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
128
+ """
129
+
130
+ def get_tensors (node : torch .fx .Node ) -> torch .Tensor :
131
+ val = node .meta ["val" ]
132
+ if isinstance (val , torch .Tensor ):
133
+ return [val ]
134
+ elif isinstance (val , (list , tuple )):
135
+ return [tensor for tensor in val if isinstance (tensor , torch .Tensor )]
136
+ return []
137
+
138
+ for node in gm .graph .nodes :
139
+ if node .op == "call_function" :
140
+ for tensor in get_tensors (node ):
141
+ # Let's make sure dim_order is not ambiguous, raise otherwise.
142
+ # This is raising because we can't do anything about it.
143
+ # The right course of follow up action is to ask user to try with a different example input.
144
+ try :
145
+ _ = tensor .dim_order (
146
+ ambiguity_check = [
147
+ torch .contiguous_format ,
148
+ torch .channels_last ,
149
+ ]
150
+ )
151
+ except Exception :
152
+ raise AmbiguousDimOrderError
153
+
154
+ # any pass or passes, just using MemoryFormatOpsPass as an example
155
+ dim_order_pass_manager = PassManager (passes = [ExampleNOPPass ()])
156
+ dim_order_pass_manager .add_checks (detect_ambiguity )
157
+ dim_order_pass_manager (gm )
158
+
159
+
102
160
class MemoryFormatOpsPassTestUtils :
103
161
@staticmethod
104
162
def memory_format_test_runner (
105
- test_class : unittest .TestCase , test_set : MemoryFormatTestSet
163
+ test_class : unittest .TestCase ,
164
+ test_set : MemoryFormatTestSet ,
165
+ check_unambiguous_dim_order : bool = False ,
106
166
):
107
167
before = export (
108
168
test_set .module , test_set .sample_input , strict = True
@@ -121,6 +181,9 @@ def memory_format_test_runner(
121
181
before , compile_config = EdgeCompileConfig (_skip_dim_order = False )
122
182
)
123
183
184
+ if check_unambiguous_dim_order :
185
+ assert_unambiguous_dim_order (epm .exported_program ().graph_module )
186
+
124
187
# check memory format ops, if needed
125
188
if test_set .op_level_check :
126
189
aten_op_str , edge_op_str = MemoryFormatOps2Str [test_set .op ]
0 commit comments