# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- from typing import List, Optional
+ import itertools
+ import operator
+ from typing import Any, List, Optional, Tuple, Type

import torch
- from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.error import ExportError, ExportErrorType
+ from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.verification.arg_validator import (
    EdgeOpArgValidator,
    RunHigherOrderOperatorError,
)

from torch._export.verifier import (
    _check_has_fake_tensor,
-     _check_tensors_are_contiguous,
    ATenDialectVerifier,
    SpecViolationError,
    Verifier,
)

ALLOWED_META_KEYS = {"spec", "stack_trace"}
- VALID_BUILTIN_FUNCS = [
-     executorch_call_delegate,
- ]


- class EXIRATenDialectVerifier(ATenDialectVerifier):
-     def valid_builtin_funcs(self):
-         builtin_funcs = super().valid_builtin_funcs()
-         builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-         return builtin_funcs
-
-     # TODO(angelayi): Delete this function when we migrate all tests to
-     # because right now old tracer does not add ["val"] metadata
-     def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
-
-         for node in gm.graph.nodes:
-             if node.op in {"call_module", "call_method"}:
+ def _check_tensors_are_contiguous(gm: GraphModule) -> None:
+     # Tensors must be in contiguous format
+     for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
+         if isinstance(param, torch.Tensor):
+             if not param.is_contiguous():
                raise SpecViolationError(
-                     "call_module is not valid: got a class '{}' ".format(node.target),
+                     f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                )

-             if node.op == "call_function":
-                 if node.target not in self.valid_builtin_funcs():
-                     self.check_valid_op(node.target)
+
+ class EXIRATenDialectVerifier(ATenDialectVerifier):
+     def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+         return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)


def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
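For reference, the contiguity rule introduced above only walks named parameters and buffers and rejects anything reporting is_contiguous() == False, while the new EXIRATenDialectVerifier only needs to widen allowed_getattr_types so lowered backend modules and plain tensors survive verification. Below is a minimal, self-contained sketch of that contiguity pattern on an ordinary nn.Module; the Demo class and weight buffer are illustrative names, not part of ExecuTorch, and the real helper receives the exported GraphModule instead.

    import itertools

    import torch


    class Demo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # .t() returns a transposed view, so this buffer is deliberately non-contiguous
            self.register_buffer("weight", torch.randn(3, 4).t())

        def forward(self, x):
            return x + self.weight


    module = Demo()
    for name, tensor in itertools.chain(module.named_parameters(), module.named_buffers()):
        if isinstance(tensor, torch.Tensor) and not tensor.is_contiguous():
            print(f"{name} is not contiguous")  # prints: weight is not contiguous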
@@ -97,15 +89,21 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:


class EXIREdgeDialectVerifier(Verifier):
-     def __init__(self, check_edge_ops: bool = False) -> None:
+     def __init__(self, check_edge_ops: bool = True) -> None:
        self.check_edge_ops = check_edge_ops

-     def valid_builtin_funcs(self):
-         builtin_funcs = super().valid_builtin_funcs()
-         builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-         return builtin_funcs
+         if self.check_edge_ops:
+             self.check_valid_op = self.check_valid_edge_op
+         else:
+             self.check_valid_op = self.check_valid_aten_op
+
+     def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+         return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)

    def check_valid_edge_op(self, op):
+         if op in [operator.getitem]:
+             return
+
        if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
            raise SpecViolationError(
                "Operator {}.{} is not an Edge operator.".format(
@@ -116,33 +114,23 @@ def check_valid_edge_op(self, op):
    def check_valid_aten_op(self, op) -> None:
        super().check_valid_op(op)

-         op_name = op.name if hasattr(op, "name") else op.__name__
-
-         if not isinstance(op, OpOverload):
-             raise SpecViolationError(
-                 "Operator '{}' is not a registered Op".format(op_name),
-             )
-
-         if (
-             torch.Tag.core not in op.tags  # type: ignore[attr-defined]
-             and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
-         ):
-             # NOTE(qihan): whether view_copy operators are marked as canonical is still under
-             # discussion.
-             raise SpecViolationError(
-                 "Operator {}.{} is not Aten Canonical.".format(
-                     op.__module__, op.__name__
+         if isinstance(op, OpOverload):
+             if (
+                 torch.Tag.core not in op.tags  # type: ignore[attr-defined]
+                 and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
+             ):
+                 # NOTE(qihan): whether view_copy operators are marked as canonical is still under
+                 # discussion.
+                 raise SpecViolationError(
+                     "Operator {}.{} is not Aten Canonical.".format(
+                         op.__module__, op.__name__
+                     )
                )
-             )

-     def check_valid(self, gm: GraphModule) -> None:
+     def check_additional(self, gm: GraphModule) -> None:
        if self.check_edge_ops:
-             self.check_valid_op = self.check_valid_edge_op
-             super().check_valid(gm)
            _check_tensors_are_contiguous(gm)
            _check_tensor_args_matching_op_allowed_dtype(gm)
-         else:
-             self.check_valid_op = self.check_valid_aten_op

    # Additionally, edge dialect's operator must have same input dtype
    for n in gm.graph.nodes:
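The rewritten check_valid_aten_op keeps the same canonicality rule (an OpOverload must carry the core or view_copy tag) but now skips non-OpOverload targets instead of raising, and check_valid becomes check_additional so the base Verifier's own traversal can invoke it as an extra hook. Below is a quick standalone probe of that tag rule, assuming a recent PyTorch where torch.Tag and these aten overloads exist; which tags a given overload actually carries varies by PyTorch version, so treat the output as illustrative.

    import torch

    # Probe the canonicality rule: pass if the overload carries the core or view_copy tag.
    for op in (torch.ops.aten.add.Tensor, torch.ops.aten.view_copy.default):
        canonical = torch.Tag.core in op.tags or torch.Tag.view_copy in op.tags
        print(
            f"{op}: core={torch.Tag.core in op.tags}, "
            f"view_copy={torch.Tag.view_copy in op.tags}, canonical={canonical}"
        )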