Commit a5e118f

angelayi authored and facebook-github-bot committed
Verifier for exported program (#292)
Summary: X-link: pytorch/pytorch#109519

Added a verifier for the graph signature in an exported program.

Differential Revision: D48926643
1 parent c52000a
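In practice, the verifier classes changed below are instantiated and run over a graph. A minimal usage sketch, not part of the commit: `gm` stands in for an FX GraphModule already lowered to the Edge dialect, and `check_valid` is the entry point inherited from `torch._export.verifier.Verifier` (the override this commit removes called `super().check_valid(gm)`).

```python
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

# `gm` is assumed to be an Edge-dialect torch.fx.GraphModule produced
# elsewhere (e.g. by exir tracing); it is a placeholder in this sketch.
verifier = EXIREdgeDialectVerifier()  # check_edge_ops now defaults to True
verifier.check_valid(gm)  # raises SpecViolationError on a dialect violation
```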

File tree: 3 files changed (+40, -50 lines)


exir/tests/test_verification.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -145,6 +145,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(len(res), len(res_interp))
         self.assertTrue(torch.allclose(res, res_interp))
 
+
+class TestEdgeVerification(unittest.TestCase):
     def test_edge_happy(self) -> None:
         class TestModel(torch.nn.Module):
             def __init__(self):
```
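As a usage note (not part of the commit), the split-out class can be run on its own with the standard unittest loader, assuming the repository root is on `PYTHONPATH`:

```python
import unittest

# Load and run only the new TestEdgeVerification class; the dotted name
# mirrors the file path exir/tests/test_verification.py shown above.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "exir.tests.test_verification.TestEdgeVerification"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```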

exir/verification/TARGETS

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,8 +51,8 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/exir:delegate",
         "//executorch/exir:error",
+        "//executorch/exir:lowered_backend_module",
         "//executorch/exir/dialects/edge:lib",
         "//executorch/exir/emit:emit",
     ],
```

exir/verification/verifier.py

Lines changed: 37 additions & 49 deletions
```diff
@@ -4,20 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional
+import itertools
+import operator
+from typing import Any, List, Optional, Tuple, Type
 
 import torch
-from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType
+from executorch.exir.lowered_backend_module import LoweredBackendModule
 from executorch.exir.verification.arg_validator import (
     EdgeOpArgValidator,
     RunHigherOrderOperatorError,
 )
 
 from torch._export.verifier import (
     _check_has_fake_tensor,
-    _check_tensors_are_contiguous,
     ATenDialectVerifier,
     SpecViolationError,
     Verifier,
@@ -28,30 +29,21 @@
 
 
 ALLOWED_META_KEYS = {"spec", "stack_trace"}
-VALID_BUILTIN_FUNCS = [
-    executorch_call_delegate,
-]
 
 
-class EXIRATenDialectVerifier(ATenDialectVerifier):
-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
-
-    # TODO(angelayi): Delete this function when we migrate all tests to
-    # because right now old tracer does not add ["val"] metadata
-    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
-
-        for node in gm.graph.nodes:
-            if node.op in {"call_module", "call_method"}:
+def _check_tensors_are_contiguous(gm: GraphModule) -> None:
+    # Tensors be of contiguous format
+    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
+        if isinstance(param, torch.Tensor):
+            if not param.is_contiguous():
                 raise SpecViolationError(
-                    "call_module is not valid: got a class '{}' ".format(node.target),
+                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                 )
 
-            if node.op == "call_function":
-                if node.target not in self.valid_builtin_funcs():
-                    self.check_valid_op(node.target)
+
+class EXIRATenDialectVerifier(ATenDialectVerifier):
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)
 
 
 def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
@@ -97,15 +89,21 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
 
 
 class EXIREdgeDialectVerifier(Verifier):
-    def __init__(self, check_edge_ops: bool = False) -> None:
+    def __init__(self, check_edge_ops: bool = True) -> None:
         self.check_edge_ops = check_edge_ops
 
-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
+        if self.check_edge_ops:
+            self.check_valid_op = self.check_valid_edge_op
+        else:
+            self.check_valid_op = self.check_valid_aten_op
+
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)
 
     def check_valid_edge_op(self, op):
+        if op in [operator.getitem]:
+            return
+
         if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
             raise SpecViolationError(
                 "Operator {}.{} is not an Edge operator.".format(
@@ -116,33 +114,23 @@ def check_valid_edge_op(self, op):
     def check_valid_aten_op(self, op) -> None:
         super().check_valid_op(op)
 
-        op_name = op.name if hasattr(op, "name") else op.__name__
-
-        if not isinstance(op, OpOverload):
-            raise SpecViolationError(
-                "Operator '{}' is not a registered Op".format(op_name),
-            )
-
-        if (
-            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
-            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
-        ):
-            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
-            # discussion.
-            raise SpecViolationError(
-                "Operator {}.{} is not Aten Canonical.".format(
-                    op.__module__, op.__name__
+        if isinstance(op, OpOverload):
+            if (
+                torch.Tag.core not in op.tags  # type: ignore[attr-defined]
+                and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
+            ):
+                # NOTE(qihan): whether view_copy operators are marked as canonical is still under
+                # discussion.
+                raise SpecViolationError(
+                    "Operator {}.{} is not Aten Canonical.".format(
+                        op.__module__, op.__name__
+                    )
                 )
-            )
 
-    def check_valid(self, gm: GraphModule) -> None:
+    def check_additional(self, gm: GraphModule) -> None:
         if self.check_edge_ops:
-            self.check_valid_op = self.check_valid_edge_op
-            super().check_valid(gm)
             _check_tensors_are_contiguous(gm)
             _check_tensor_args_matching_op_allowed_dtype(gm)
-        else:
-            self.check_valid_op = self.check_valid_aten_op
 
         # Additionally, edge dialect's operator must have same input dtype
         for n in gm.graph.nodes:
```
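The new module-level `_check_tensors_are_contiguous` helper replaces the version previously imported from `torch._export.verifier`. Below is a self-contained sketch of the same check, runnable without ExecuTorch; the toy model is invented for illustration, and a plain `print` stands in for raising `SpecViolationError`.

```python
import itertools

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A transposed view is deliberately non-contiguous, to trip the check.
        self.weight = torch.nn.Parameter(torch.randn(4, 2).t())

    def forward(self, x):
        return x @ self.weight


gm = torch.fx.symbolic_trace(TinyModel())
# Same traversal as the helper above: parameters and buffers must be contiguous.
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
    if isinstance(param, torch.Tensor) and not param.is_contiguous():
        print(f"{name} is not contiguous")  # prints: weight is not contiguous
```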
