Skip to content

Commit c3d1680

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for exception list in EXIRATenDialectVerifierBase (#3481)
Summary: Pull Request resolved: #3481 Adding support for an exception list in EXIRATenDialectVerifierBase to support the implementation of `to_edge_transform_and_lower`. We'll pass in the list of ops that have been registered to not be decomposed into this verifier so that it skips these ops. Reviewed By: larryliu0820 Differential Revision: D56560549 fbshipit-source-id: 88e7f52d8ba97b9caf11aac76fecfed4a0602217
1 parent 9d4727d commit c3d1680

File tree

1 file changed

+44
-22
lines changed

1 file changed

+44
-22
lines changed

exir/verification/verifier.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None:
5151
class EXIRATenDialectVerifierBase(Verifier):
5252
dialect = "OLD_EXIR_ATEN_DISABLED"
5353

54+
def __init__(
55+
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
56+
) -> None:
57+
super().__init__()
58+
self._exception_list = exception_list if exception_list else []
59+
5460
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
5561
return (
5662
torch.fx.GraphModule,
@@ -74,23 +80,33 @@ def __call__(self, *args, **kwargs):
7480
class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
7581
dialect = "OLD_EXIR_ATEN"
7682

83+
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
84+
exception_list = [
85+
torch.ops.aten.mkldnn_rnn_layer.default,
86+
torch.ops.aten._upsample_bilinear2d_aa.default,
87+
torch.ops.aten.quantize_per_tensor.default,
88+
torch.ops.aten.dequantize.self,
89+
torch.ops.aten.max.default, # TODO(T188268054)
90+
torch.ops.aten.min.default, # TODO(T188268054)
91+
torch.ops.aten.full_like.default, # TODO(T183507359)
92+
]
93+
exception_list += self._exception_list
94+
95+
return exception_list
96+
7797
def check_valid_op(self, op):
7898
if isinstance(op, OpOverload):
7999
# TODO These special ops should be removable easily.
80-
if op.namespace in (
81-
"quantized_decomposed",
82-
"boltnn_nimble",
83-
"nimble",
84-
"quantized",
85-
"dim_order_ops",
86-
) or op in (
87-
torch.ops.aten.mkldnn_rnn_layer.default,
88-
torch.ops.aten._upsample_bilinear2d_aa.default,
89-
torch.ops.aten.quantize_per_tensor.default,
90-
torch.ops.aten.dequantize.self,
91-
torch.ops.aten.max.default, # TODO(T188268054)
92-
torch.ops.aten.min.default, # TODO(T188268054)
93-
torch.ops.aten.full_like.default, # TODO(T183507359)
100+
if (
101+
op.namespace
102+
in [
103+
"quantized_decomposed",
104+
"boltnn_nimble",
105+
"nimble",
106+
"quantized",
107+
"dim_order_ops",
108+
]
109+
or op in self._get_exception_list()
94110
):
95111
return
96112
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
@@ -150,6 +166,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
150166
def EXIREdgeDialectVerifier( # noqa: C901
151167
edge_compile_config: Optional[EdgeCompileConfig] = None,
152168
class_only: bool = False,
169+
exception_list: Optional[List[torch._ops.OpOverload]] = None,
153170
):
154171
class _EXIREdgeDialectVerifier(Verifier):
155172
dialect = "EDGE"
@@ -161,13 +178,14 @@ def __init__(self) -> None:
161178
self.check_edge_ops = _edge_compile_config._use_edge_ops
162179
self.use_dim_order = not _edge_compile_config._skip_dim_order
163180

164-
self.aten_op_verifier = EXIRATenDialectVerifier()
181+
self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
165182
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
166183

167184
if self.check_edge_ops:
168185
self.check_valid_op = self.check_valid_edge_op
169186
else:
170187
self.check_valid_op = self.check_valid_aten_op
188+
self._exception_list = exception_list if exception_list else []
171189

172190
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
173191
return (
@@ -183,13 +201,17 @@ def allowed_op_types(self):
183201
def check_valid_edge_op(self, op):
184202
if not self.enable:
185203
return
186-
if op in [
187-
operator.getitem,
188-
torch.ops.aten.sym_size.int,
189-
torch.ops.aten.scalar_tensor.default,
190-
torch.ops.aten._assert_async.msg,
191-
torch.ops.aten._assert_scalar.default,
192-
]:
204+
if (
205+
op
206+
in [
207+
operator.getitem,
208+
torch.ops.aten.sym_size.int,
209+
torch.ops.aten.scalar_tensor.default,
210+
torch.ops.aten._assert_async.msg,
211+
torch.ops.aten._assert_scalar.default,
212+
]
213+
+ self._exception_list
214+
):
193215
return
194216

195217
if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):

0 commit comments

Comments
 (0)