@@ -52,12 +52,6 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
52
52
class EXIRATenDialectVerifierBase (Verifier ):
53
53
dialect = "OLD_EXIR_ATEN_DISABLED"
54
54
55
- def __init__ (
56
- self , exception_list : Optional [List [torch ._ops .OpOverload ]] = None
57
- ) -> None :
58
- super ().__init__ ()
59
- self ._exception_list = exception_list if exception_list else []
60
-
61
55
def allowed_getattr_types (self ) -> Tuple [Type [Any ], ...]:
62
56
return (
63
57
torch .fx .GraphModule ,
@@ -78,38 +72,68 @@ def __call__(self, *args, **kwargs):
78
72
raise RuntimeError ("" )
79
73
80
74
81
- class EXIRATenDialectVerifier (EXIRATenDialectVerifierBase ):
82
- dialect = "OLD_EXIR_ATEN"
75
+ def EXIRATenDialectVerifier ( # noqa: C901
76
+ edge_compile_config : Optional [EdgeCompileConfig ] = None ,
77
+ class_only : bool = False ,
78
+ exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
79
+ ):
80
+ """
81
+ Returns a verifier class that runs ATen dialect specific checks on the graph module.
82
+ """
83
+ # merge the exception list from edge_compile_config and exception_list
84
+ if edge_compile_config and edge_compile_config ._core_aten_ops_exception_list :
85
+ exception_list = edge_compile_config ._core_aten_ops_exception_list + (
86
+ exception_list or []
87
+ )
83
88
84
- def _get_exception_list (self ) -> List [torch ._ops .OpOverload ]:
85
- exception_list = [
86
- torch .ops .aten .mkldnn_rnn_layer .default ,
87
- torch .ops .aten ._upsample_bilinear2d_aa .default ,
88
- torch .ops .aten .quantize_per_tensor .default ,
89
- torch .ops .aten .dequantize .self ,
90
- torch .ops .aten .max .default , # TODO(T188268054)
91
- torch .ops .aten .min .default , # TODO(T188268054)
92
- torch .ops .aten .full_like .default , # TODO(T183507359)
93
- ]
94
- exception_list += self ._exception_list
89
+ class _EXIRATenDialectVerifier (EXIRATenDialectVerifierBase ):
90
+ dialect = "OLD_EXIR_ATEN"
95
91
96
- return exception_list
92
+ def __init__ (self ) -> None :
93
+ super ().__init__ ()
94
+ # Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
95
+ self ._exception_list = exception_list if exception_list else []
97
96
98
- def check_valid_op (self , op ) :
99
- if isinstance ( op , OpOverload ):
100
- # TODO These special ops should be removable easily.
101
- if op . namespace != "aten" or op in self . _get_exception_list ():
102
- return
103
- if torch . Tag . core not in op . tags and torch .Tag . view_copy not in op . tags :
104
- # NOTE(qihan): whether view_copy operators are marked as canonical is still under
105
- # discussion.
106
- raise SpecViolationError (
107
- f"Operator { op . __module__ } . { op . __name__ } is not Aten Canonical."
108
- )
97
+ def _get_exception_list (self ) -> List [ torch . _ops . OpOverload ] :
98
+ exception_list = [
99
+ torch . ops . aten . mkldnn_rnn_layer . default ,
100
+ torch . ops . aten . _upsample_bilinear2d_aa . default ,
101
+ torch . ops . aten . quantize_per_tensor . default ,
102
+ torch .ops . aten . dequantize . self ,
103
+ torch . ops . aten . max . default , # TODO(T188268054)
104
+ torch . ops . aten . min . default , # TODO(T188268054)
105
+ torch . ops . aten . full_like . default , # TODO(T183507359)
106
+ ]
107
+ exception_list += self . _exception_list
109
108
109
+ return exception_list
110
110
111
- def get_aten_verifier (enable : bool = True ):
112
- return EXIRATenDialectVerifier if enable else EXIRATenDialectVerifierBase
111
+ def check_valid_op (self , op ):
112
+ if isinstance (op , OpOverload ):
113
+ # TODO These special ops should be removable easily.
114
+ if op .namespace != "aten" or op in self ._get_exception_list ():
115
+ return
116
+ if torch .Tag .core not in op .tags and torch .Tag .view_copy not in op .tags :
117
+ # NOTE(qihan): whether view_copy operators are marked as canonical is still under
118
+ # discussion.
119
+ raise SpecViolationError (
120
+ f"Operator { op .__module__ } .{ op .__name__ } is not Aten Canonical."
121
+ )
122
+
123
+ ret = _EXIRATenDialectVerifier
124
+ if not class_only :
125
+ ret = ret ()
126
+ return ret
127
+
128
+
129
+ def get_aten_verifier (config : EdgeCompileConfig ):
130
+ return (
131
+ EXIRATenDialectVerifier (
132
+ class_only = True , exception_list = config ._core_aten_ops_exception_list
133
+ )
134
+ if config ._check_ir_validity
135
+ else EXIRATenDialectVerifierBase
136
+ )
113
137
114
138
115
139
def _get_inputs (graph_module : GraphModule ) -> List [Optional [FakeTensor ]]:
@@ -160,6 +184,12 @@ def EXIREdgeDialectVerifier( # noqa: C901
160
184
class_only : bool = False ,
161
185
exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
162
186
):
187
+ # merge the exception list from edge_compile_config and exception_list
188
+ if edge_compile_config and edge_compile_config ._core_aten_ops_exception_list :
189
+ exception_list = edge_compile_config ._core_aten_ops_exception_list + (
190
+ exception_list or []
191
+ )
192
+
163
193
class _EXIREdgeDialectVerifier (Verifier ):
164
194
dialect = "EDGE"
165
195
@@ -170,7 +200,9 @@ def __init__(self) -> None:
170
200
self .check_edge_ops = _edge_compile_config ._use_edge_ops
171
201
self .use_dim_order = not _edge_compile_config ._skip_dim_order
172
202
173
- self .aten_op_verifier = EXIRATenDialectVerifier (exception_list )
203
+ self .aten_op_verifier = EXIRATenDialectVerifier (
204
+ exception_list = exception_list
205
+ )
174
206
self .check_valid_aten_op = self .aten_op_verifier .check_valid_op
175
207
176
208
if self .check_edge_ops :
0 commit comments