Skip to content

Commit b65983a

Browse files
angelayi authored and facebook-github-bot committed
Verifier for exported program
Summary: Added a verifier for the graph signature in an exported program Differential Revision: D48926643
1 parent 2f21fe6 commit b65983a

File tree

2 files changed

+41
-50
lines changed

2 files changed

+41
-50
lines changed

exir/tests/test_verification.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
145145
self.assertEqual(len(res), len(res_interp))
146146
self.assertTrue(torch.allclose(res, res_interp))
147147

148+
149+
class TestEdgeVerification(unittest.TestCase):
148150
def test_edge_happy(self) -> None:
149151
class TestModel(torch.nn.Module):
150152
def __init__(self):

exir/verification/verifier.py

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import itertools
8+
import operator
79
from typing import List, Optional
810

911
import torch
10-
from executorch.exir.delegate import executorch_call_delegate
1112
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1213
from executorch.exir.error import ExportError, ExportErrorType
1314
from executorch.exir.verification.arg_validator import (
@@ -16,8 +17,7 @@
1617
)
1718

1819
from torch._export.verifier import (
19-
_check_has_fake_tensor,
20-
_check_tensors_are_contiguous,
20+
_check_val,
2121
ATenDialectVerifier,
2222
SpecViolationError,
2323
Verifier,
@@ -28,30 +28,22 @@
2828

2929

3030
ALLOWED_META_KEYS = {"spec", "stack_trace"}
31-
VALID_BUILTIN_FUNCS = [
32-
executorch_call_delegate,
33-
]
3431

3532

36-
class EXIRATenDialectVerifier(ATenDialectVerifier):
37-
def valid_builtin_funcs(self):
38-
builtin_funcs = super().valid_builtin_funcs()
39-
builtin_funcs.extend(VALID_BUILTIN_FUNCS)
40-
return builtin_funcs
41-
42-
# TODO(angelayi): Delete this function when we migrate all tests to
43-
# because right now old tracer does not add ["val"] metadata
44-
def check_valid(self, gm: GraphModule) -> None: # noqa: C901
45-
46-
for node in gm.graph.nodes:
47-
if node.op in {"call_module", "call_method"}:
33+
def _check_tensors_are_contiguous(gm: GraphModule) -> None:
34+
# Tensors be of contiguous format
35+
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
36+
if isinstance(param, torch.Tensor):
37+
if not param.is_contiguous():
4838
raise SpecViolationError(
49-
"call_module is not valid: got a class '{}' ".format(node.target),
39+
f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
5040
)
5141

52-
if node.op == "call_function":
53-
if node.target not in self.valid_builtin_funcs():
54-
self.check_valid_op(node.target)
42+
43+
class EXIRATenDialectVerifier(ATenDialectVerifier):
44+
def _check_attribute(self, mod: torch.fx.GraphModule, target: str) -> None:
45+
# TODO: remove this once Executorch fully migrates to torch.export
46+
pass
5547

5648

5749
def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
@@ -97,15 +89,22 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
9789

9890

9991
class EXIREdgeDialectVerifier(Verifier):
100-
def __init__(self, check_edge_ops: bool = False) -> None:
92+
def __init__(self, check_edge_ops: bool = True) -> None:
10193
self.check_edge_ops = check_edge_ops
10294

103-
def valid_builtin_funcs(self):
104-
builtin_funcs = super().valid_builtin_funcs()
105-
builtin_funcs.extend(VALID_BUILTIN_FUNCS)
106-
return builtin_funcs
95+
if self.check_edge_ops:
96+
self.check_valid_op = self.check_valid_edge_op
97+
else:
98+
self.check_valid_op = self.check_valid_aten_op
99+
100+
def _check_attribute(self, mod: torch.fx.GraphModule, target: str) -> None:
101+
# TODO: remove this once Executorch fully migrates to torch.export
102+
pass
107103

108104
def check_valid_edge_op(self, op):
105+
if op is operator.getitem:
106+
return
107+
109108
if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
110109
raise SpecViolationError(
111110
"Operator {}.{} is not an Edge operator.".format(
@@ -116,38 +115,28 @@ def check_valid_edge_op(self, op):
116115
def check_valid_aten_op(self, op) -> None:
117116
super().check_valid_op(op)
118117

119-
op_name = op.name if hasattr(op, "name") else op.__name__
120-
121-
if not isinstance(op, OpOverload):
122-
raise SpecViolationError(
123-
"Operator '{}' is not a registered Op".format(op_name),
124-
)
125-
126-
if (
127-
torch.Tag.core not in op.tags # type: ignore[attr-defined]
128-
and torch.Tag.view_copy not in op.tags # type: ignore[attr-defined]
129-
):
130-
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
131-
# discussion.
132-
raise SpecViolationError(
133-
"Operator {}.{} is not Aten Canonical.".format(
134-
op.__module__, op.__name__
118+
if isinstance(op, OpOverload):
119+
if (
120+
torch.Tag.core not in op.tags # type: ignore[attr-defined]
121+
and torch.Tag.view_copy not in op.tags # type: ignore[attr-defined]
122+
):
123+
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
124+
# discussion.
125+
raise SpecViolationError(
126+
"Operator {}.{} is not Aten Canonical.".format(
127+
op.__module__, op.__name__
128+
)
135129
)
136-
)
137130

138-
def check_valid(self, gm: GraphModule) -> None:
131+
def check_additional(self, gm: GraphModule) -> None:
139132
if self.check_edge_ops:
140-
self.check_valid_op = self.check_valid_edge_op
141-
super().check_valid(gm)
142133
_check_tensors_are_contiguous(gm)
143134
_check_tensor_args_matching_op_allowed_dtype(gm)
144-
else:
145-
self.check_valid_op = self.check_valid_aten_op
146135

147136
# Additionally, edge dialect's operator must have same input dtype
148137
for n in gm.graph.nodes:
149138
if n.op == "call_function" and isinstance(n.target, OpOverload):
150-
_check_has_fake_tensor(n)
139+
_check_val(n)
151140
dtypes = set()
152141
for arg in n.args:
153142
if isinstance(arg, torch.Tensor):

0 commit comments

Comments
 (0)