@@ -51,6 +51,12 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
51
51
class EXIRATenDialectVerifierBase (Verifier ):
52
52
dialect = "OLD_EXIR_ATEN_DISABLED"
53
53
54
+ def __init__ (
55
+ self , exception_list : Optional [List [torch ._ops .OpOverload ]] = None
56
+ ) -> None :
57
+ super ().__init__ ()
58
+ self ._exception_list = exception_list if exception_list else []
59
+
54
60
def allowed_getattr_types (self ) -> Tuple [Type [Any ], ...]:
55
61
return (
56
62
torch .fx .GraphModule ,
@@ -74,23 +80,33 @@ def __call__(self, *args, **kwargs):
74
80
class EXIRATenDialectVerifier (EXIRATenDialectVerifierBase ):
75
81
dialect = "OLD_EXIR_ATEN"
76
82
83
+ def _get_exception_list (self ) -> List [torch ._ops .OpOverload ]:
84
+ exception_list = [
85
+ torch .ops .aten .mkldnn_rnn_layer .default ,
86
+ torch .ops .aten ._upsample_bilinear2d_aa .default ,
87
+ torch .ops .aten .quantize_per_tensor .default ,
88
+ torch .ops .aten .dequantize .self ,
89
+ torch .ops .aten .max .default , # TODO(T188268054)
90
+ torch .ops .aten .min .default , # TODO(T188268054)
91
+ torch .ops .aten .full_like .default , # TODO(T183507359)
92
+ ]
93
+ exception_list += self ._exception_list
94
+
95
+ return exception_list
96
+
77
97
def check_valid_op (self , op ):
78
98
if isinstance (op , OpOverload ):
79
99
# TODO These special ops should be removable easily.
80
- if op .namespace in (
81
- "quantized_decomposed" ,
82
- "boltnn_nimble" ,
83
- "nimble" ,
84
- "quantized" ,
85
- "dim_order_ops" ,
86
- ) or op in (
87
- torch .ops .aten .mkldnn_rnn_layer .default ,
88
- torch .ops .aten ._upsample_bilinear2d_aa .default ,
89
- torch .ops .aten .quantize_per_tensor .default ,
90
- torch .ops .aten .dequantize .self ,
91
- torch .ops .aten .max .default , # TODO(T188268054)
92
- torch .ops .aten .min .default , # TODO(T188268054)
93
- torch .ops .aten .full_like .default , # TODO(T183507359)
100
+ if (
101
+ op .namespace
102
+ in [
103
+ "quantized_decomposed" ,
104
+ "boltnn_nimble" ,
105
+ "nimble" ,
106
+ "quantized" ,
107
+ "dim_order_ops" ,
108
+ ]
109
+ or op in self ._get_exception_list ()
94
110
):
95
111
return
96
112
if torch .Tag .core not in op .tags and torch .Tag .view_copy not in op .tags :
@@ -150,6 +166,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
150
166
def EXIREdgeDialectVerifier ( # noqa: C901
151
167
edge_compile_config : Optional [EdgeCompileConfig ] = None ,
152
168
class_only : bool = False ,
169
+ exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
153
170
):
154
171
class _EXIREdgeDialectVerifier (Verifier ):
155
172
dialect = "EDGE"
@@ -161,13 +178,14 @@ def __init__(self) -> None:
161
178
self .check_edge_ops = _edge_compile_config ._use_edge_ops
162
179
self .use_dim_order = not _edge_compile_config ._skip_dim_order
163
180
164
- self .aten_op_verifier = EXIRATenDialectVerifier ()
181
+ self .aten_op_verifier = EXIRATenDialectVerifier (exception_list )
165
182
self .check_valid_aten_op = self .aten_op_verifier .check_valid_op
166
183
167
184
if self .check_edge_ops :
168
185
self .check_valid_op = self .check_valid_edge_op
169
186
else :
170
187
self .check_valid_op = self .check_valid_aten_op
188
+ self ._exception_list = exception_list if exception_list else []
171
189
172
190
def allowed_getattr_types (self ) -> Tuple [Type [Any ], ...]:
173
191
return (
@@ -183,13 +201,17 @@ def allowed_op_types(self):
183
201
def check_valid_edge_op (self , op ):
184
202
if not self .enable :
185
203
return
186
- if op in [
187
- operator .getitem ,
188
- torch .ops .aten .sym_size .int ,
189
- torch .ops .aten .scalar_tensor .default ,
190
- torch .ops .aten ._assert_async .msg ,
191
- torch .ops .aten ._assert_scalar .default ,
192
- ]:
204
+ if (
205
+ op
206
+ in [
207
+ operator .getitem ,
208
+ torch .ops .aten .sym_size .int ,
209
+ torch .ops .aten .scalar_tensor .default ,
210
+ torch .ops .aten ._assert_async .msg ,
211
+ torch .ops .aten ._assert_scalar .default ,
212
+ ]
213
+ + self ._exception_list
214
+ ):
193
215
return
194
216
195
217
if isinstance (op , OpOverload ) and not isinstance (op , EdgeOpOverload ):
0 commit comments