@@ -41,6 +41,12 @@ def _check_tensors_are_contiguous(gm: GraphModule) -> None:
41
41
class EXIRATenDialectVerifierBase (Verifier ):
42
42
dialect = "OLD_EXIR_ATEN_DISABLED"
43
43
44
+ def __init__ (
45
+ self , exception_list : Optional [List [torch ._ops .OpOverload ]] = None
46
+ ) -> None :
47
+ super ().__init__ ()
48
+ self ._exception_list = exception_list if exception_list else []
49
+
44
50
def allowed_getattr_types (self ) -> Tuple [Type [Any ], ...]:
45
51
return (
46
52
torch .fx .GraphModule ,
@@ -67,18 +73,24 @@ class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
67
73
def check_valid_op (self , op ):
68
74
if isinstance (op , OpOverload ):
69
75
# TODO These special ops should be removable easily.
70
- if op .namespace in (
71
- "quantized_decomposed" ,
72
- "boltnn_nimble" ,
73
- "nimble" ,
74
- "quantized" ,
75
- ) or op in (
76
- torch .ops .aten .mkldnn_rnn_layer .default ,
77
- torch .ops .aten ._upsample_bilinear2d_aa .default ,
78
- torch .ops .aten .quantize_per_tensor .default ,
79
- torch .ops .aten .dequantize .self ,
80
- torch .ops .aten .max .default ,
81
- torch .ops .aten .full_like .default , # TODO(T183507359)
76
+ if (
77
+ op .namespace
78
+ in (
79
+ "quantized_decomposed" ,
80
+ "boltnn_nimble" ,
81
+ "nimble" ,
82
+ "quantized" ,
83
+ )
84
+ or op
85
+ in [
86
+ torch .ops .aten .mkldnn_rnn_layer .default ,
87
+ torch .ops .aten ._upsample_bilinear2d_aa .default ,
88
+ torch .ops .aten .quantize_per_tensor .default ,
89
+ torch .ops .aten .dequantize .self ,
90
+ torch .ops .aten .max .default ,
91
+ torch .ops .aten .full_like .default , # TODO(T183507359)
92
+ ]
93
+ + self ._exception_list
82
94
):
83
95
return
84
96
if torch .Tag .core not in op .tags and torch .Tag .view_copy not in op .tags :
@@ -139,19 +151,21 @@ def EXIREdgeDialectVerifier( # noqa: C901
139
151
check_edge_ops : bool = True ,
140
152
enable : bool = True ,
141
153
class_only : bool = False ,
154
+ exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
142
155
):
143
156
class _EXIREdgeDialectVerifier (Verifier ):
144
157
dialect = "EDGE"
145
158
146
159
def __init__ (self ) -> None :
147
160
self .check_edge_ops = check_edge_ops
148
- self .aten_op_verifier = EXIRATenDialectVerifier ()
161
+ self .aten_op_verifier = EXIRATenDialectVerifier (exception_list )
149
162
self .check_valid_aten_op = self .aten_op_verifier .check_valid_op
150
163
151
164
if self .check_edge_ops :
152
165
self .check_valid_op = self .check_valid_edge_op
153
166
else :
154
167
self .check_valid_op = self .check_valid_aten_op
168
+ self ._exception_list = exception_list if exception_list else []
155
169
156
170
def allowed_getattr_types (self ) -> Tuple [Type [Any ], ...]:
157
171
return (
@@ -167,7 +181,11 @@ def allowed_op_types(self):
167
181
def check_valid_edge_op (self , op ):
168
182
if not enable :
169
183
return
170
- if op in [operator .getitem , torch .ops .aten .sym_size .int ]:
184
+ if (
185
+ op
186
+ in [operator .getitem , torch .ops .aten .sym_size .int ]
187
+ + self ._exception_list
188
+ ):
171
189
return
172
190
173
191
if isinstance (op , OpOverload ) and not isinstance (op , EdgeOpOverload ):
0 commit comments