Skip to content

Commit 60df682

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Cleanup EXIREdgeDialectVerifier to use the same logic to verify core (#1504)
Summary: Pull Request resolved: #1504 Making sure EXIREdgeDialectVerifier is able to verify core ATen ops correctly. Getting rid of forked logic of `check_valid_aten_op`. Reviewed By: iseeyuan Differential Revision: D52492252 fbshipit-source-id: 0f0ad2adea99753ce8a804e5dff349213036cbbb
1 parent a81c2d4 commit 60df682

File tree

3 files changed

+43
-14
lines changed

3 files changed

+43
-14
lines changed

exir/verification/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension")
22
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
3+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
34

45
oncall("executorch")
56

@@ -57,3 +58,13 @@ python_library(
5758
"//executorch/exir/emit:emit",
5859
],
5960
)
61+
62+
python_unittest(
63+
name = "test_verifier",
64+
srcs = ["test/test_verifier.py"],
65+
deps = [
66+
":verifier",
67+
"//caffe2:torch",
68+
"//executorch/exir/dialects:lib",
69+
],
70+
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from contextlib import contextmanager
9+
10+
from executorch.exir.dialects._ops import ops
11+
from torch._export.verifier import SpecViolationError
12+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
13+
14+
from ..verifier import EXIREdgeDialectVerifier
15+
16+
17+
class TestEdgeDialectVerifier(unittest.TestCase):
18+
@contextmanager
19+
def assertNotRaises(self, exc_type):
20+
try:
21+
yield None
22+
except exc_type:
23+
raise self.failureException("{} raised".format(exc_type.__name__))
24+
25+
def test_edge_verifier_check_valid_op_succeed_given_custom_op(self) -> None:
26+
edge_op = ops.edge.quantized_decomposed.quantize_per_tensor.default
27+
verifier = EXIREdgeDialectVerifier(check_edge_ops=True)
28+
with self.assertNotRaises(SpecViolationError):
29+
verifier.check_valid_edge_op(edge_op)
30+
verifier.check_valid_op(edge_op)

exir/verification/verifier.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ class _EXIREdgeDialectVerifier(Verifier):
144144

145145
def __init__(self) -> None:
146146
self.check_edge_ops = check_edge_ops
147+
self.aten_op_verifier = EXIRATenDialectVerifier()
148+
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
147149

148150
if self.check_edge_ops:
149151
self.check_valid_op = self.check_valid_edge_op
@@ -178,20 +180,6 @@ def check_valid_edge_op(self, op):
178180
if isinstance(op, types.FunctionType):
179181
assert op.__name__ in ("alloc",)
180182

181-
def check_valid_aten_op(self, op) -> None:
182-
if isinstance(op, OpOverload):
183-
if (
184-
torch.Tag.core not in op.tags # type: ignore[attr-defined]
185-
and torch.Tag.view_copy not in op.tags # type: ignore[attr-defined]
186-
):
187-
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
188-
# discussion.
189-
raise SpecViolationError(
190-
"Operator {}.{} is not Aten Canonical.".format(
191-
op.__module__, op.__name__
192-
)
193-
)
194-
195183
def check_additional(self, gm: GraphModule) -> None:
196184
if not enable:
197185
return

0 commit comments

Comments
 (0)