Skip to content

Commit de76a53

Browse files
tarun292facebook-github-bot
authored andcommitted
Add support for exception list in EXIRATenDialectVerifierBase
Summary: Adding support for an exception list in EXIRATenDialectVerifierBase to support the implementation of `to_edge_transform_and_lower`. We'll pass in the list of ops that have been registered to not be decomposed into this verifier so that it skips these ops. Differential Revision: D56560549
1 parent 2a4fcb4 commit de76a53

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

exir/verification/verifier.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def _check_tensors_are_contiguous(gm: GraphModule) -> None:
4141
class EXIRATenDialectVerifierBase(Verifier):
4242
dialect = "OLD_EXIR_ATEN_DISABLED"
4343

44+
def __init__(
45+
self, exception_list: Optional[List[torch._ops.OpOverload]] = None
46+
) -> None:
47+
super().__init__()
48+
self._exception_list = exception_list if exception_list else []
49+
4450
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
4551
return (
4652
torch.fx.GraphModule,
@@ -67,18 +73,24 @@ class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
6773
def check_valid_op(self, op):
6874
if isinstance(op, OpOverload):
6975
# TODO These special ops should be removable easily.
70-
if op.namespace in (
71-
"quantized_decomposed",
72-
"boltnn_nimble",
73-
"nimble",
74-
"quantized",
75-
) or op in (
76-
torch.ops.aten.mkldnn_rnn_layer.default,
77-
torch.ops.aten._upsample_bilinear2d_aa.default,
78-
torch.ops.aten.quantize_per_tensor.default,
79-
torch.ops.aten.dequantize.self,
80-
torch.ops.aten.max.default,
81-
torch.ops.aten.full_like.default, # TODO(T183507359)
76+
if (
77+
op.namespace
78+
in (
79+
"quantized_decomposed",
80+
"boltnn_nimble",
81+
"nimble",
82+
"quantized",
83+
)
84+
or op
85+
in [
86+
torch.ops.aten.mkldnn_rnn_layer.default,
87+
torch.ops.aten._upsample_bilinear2d_aa.default,
88+
torch.ops.aten.quantize_per_tensor.default,
89+
torch.ops.aten.dequantize.self,
90+
torch.ops.aten.max.default,
91+
torch.ops.aten.full_like.default, # TODO(T183507359)
92+
]
93+
+ self._exception_list
8294
):
8395
return
8496
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
@@ -139,19 +151,21 @@ def EXIREdgeDialectVerifier( # noqa: C901
139151
check_edge_ops: bool = True,
140152
enable: bool = True,
141153
class_only: bool = False,
154+
exception_list: Optional[List[torch._ops.OpOverload]] = None,
142155
):
143156
class _EXIREdgeDialectVerifier(Verifier):
144157
dialect = "EDGE"
145158

146159
def __init__(self) -> None:
147160
self.check_edge_ops = check_edge_ops
148-
self.aten_op_verifier = EXIRATenDialectVerifier()
161+
self.aten_op_verifier = EXIRATenDialectVerifier(exception_list)
149162
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
150163

151164
if self.check_edge_ops:
152165
self.check_valid_op = self.check_valid_edge_op
153166
else:
154167
self.check_valid_op = self.check_valid_aten_op
168+
self._exception_list = exception_list if exception_list else []
155169

156170
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
157171
return (
@@ -167,7 +181,11 @@ def allowed_op_types(self):
167181
def check_valid_edge_op(self, op):
168182
if not enable:
169183
return
170-
if op in [operator.getitem, torch.ops.aten.sym_size.int]:
184+
if (
185+
op
186+
in [operator.getitem, torch.ops.aten.sym_size.int]
187+
+ self._exception_list
188+
):
171189
return
172190

173191
if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):

0 commit comments

Comments
 (0)