File tree Expand file tree Collapse file tree 3 files changed +43
-14
lines changed Expand file tree Collapse file tree 3 files changed +43
-14
lines changed Original file line number Diff line number Diff line change 1
1
load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension")
2
2
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
3
+ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
3
4
4
5
oncall("executorch")
5
6
@@ -57,3 +58,13 @@ python_library(
57
58
"//executorch/exir/emit:emit",
58
59
],
59
60
)
61
+
62
+ python_unittest(
63
+ name = "test_verifier",
64
+ srcs = ["test/test_verifier.py"],
65
+ deps = [
66
+ ":verifier",
67
+ "//caffe2:torch",
68
+ "//executorch/exir/dialects:lib",
69
+ ],
70
+ )
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import unittest
8
+ from contextlib import contextmanager
9
+
10
+ from executorch .exir .dialects ._ops import ops
11
+ from torch ._export .verifier import SpecViolationError
12
+ from torch .ao .quantization .fx ._decomposed import quantized_decomposed_lib # noqa: F401
13
+
14
+ from ..verifier import EXIREdgeDialectVerifier
15
+
16
+
17
+ class TestEdgeDialectVerifier (unittest .TestCase ):
18
+ @contextmanager
19
+ def assertNotRaises (self , exc_type ):
20
+ try :
21
+ yield None
22
+ except exc_type :
23
+ raise self .failureException ("{} raised" .format (exc_type .__name__ ))
24
+
25
+ def test_edge_verifier_check_valid_op_succeed_given_custom_op (self ) -> None :
26
+ edge_op = ops .edge .quantized_decomposed .quantize_per_tensor .default
27
+ verifier = EXIREdgeDialectVerifier (check_edge_ops = True )
28
+ with self .assertNotRaises (SpecViolationError ):
29
+ verifier .check_valid_edge_op (edge_op )
30
+ verifier .check_valid_op (edge_op )
Original file line number Diff line number Diff line change @@ -144,6 +144,8 @@ class _EXIREdgeDialectVerifier(Verifier):
144
144
145
145
def __init__ (self ) -> None :
146
146
self .check_edge_ops = check_edge_ops
147
+ self .aten_op_verifier = EXIRATenDialectVerifier ()
148
+ self .check_valid_aten_op = self .aten_op_verifier .check_valid_op
147
149
148
150
if self .check_edge_ops :
149
151
self .check_valid_op = self .check_valid_edge_op
@@ -178,20 +180,6 @@ def check_valid_edge_op(self, op):
178
180
if isinstance (op , types .FunctionType ):
179
181
assert op .__name__ in ("alloc" ,)
180
182
181
- def check_valid_aten_op (self , op ) -> None :
182
- if isinstance (op , OpOverload ):
183
- if (
184
- torch .Tag .core not in op .tags # type: ignore[attr-defined]
185
- and torch .Tag .view_copy not in op .tags # type: ignore[attr-defined]
186
- ):
187
- # NOTE(qihan): whether view_copy operators are marked as canonical is still under
188
- # discussion.
189
- raise SpecViolationError (
190
- "Operator {}.{} is not Aten Canonical." .format (
191
- op .__module__ , op .__name__
192
- )
193
- )
194
-
195
183
def check_additional (self , gm : GraphModule ) -> None :
196
184
if not enable :
197
185
return
You can’t perform that action at this time.
0 commit comments