 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
+import operator
 from typing import List, Optional

 import torch
-from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType
 from executorch.exir.verification.arg_validator import (
 )

 from torch._export.verifier import (
-    _check_has_fake_tensor,
-    _check_tensors_are_contiguous,
+    _check_val,
     ATenDialectVerifier,
     SpecViolationError,
     Verifier,


 ALLOWED_META_KEYS = {"spec", "stack_trace"}
-VALID_BUILTIN_FUNCS = [
-    executorch_call_delegate,
-]


-class EXIRATenDialectVerifier(ATenDialectVerifier):
-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
-
-    # TODO(angelayi): Delete this function when we migrate all tests to
-    # because right now old tracer does not add ["val"] metadata
-    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
-
-        for node in gm.graph.nodes:
-            if node.op in {"call_module", "call_method"}:
+def _check_tensors_are_contiguous(gm: GraphModule) -> None:
+    # Tensors must be of contiguous format
+    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
+        if isinstance(param, torch.Tensor):
+            if not param.is_contiguous():
                 raise SpecViolationError(
-                    "call_module is not valid: got a class '{}' ".format(node.target),
+                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
                 )

-            if node.op == "call_function":
-                if node.target not in self.valid_builtin_funcs():
-                    self.check_valid_op(node.target)
+
+class EXIRATenDialectVerifier(ATenDialectVerifier):
+    def _check_attribute(self, mod: torch.fx.GraphModule, target: str) -> None:
+        # TODO: remove this once Executorch fully migrates to torch.export
+        pass


 def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
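
Note: `_check_tensors_are_contiguous` is dropped from the `torch._export.verifier` import and reimplemented locally above. A minimal standalone sketch of the same rule, runnable with plain PyTorch; the module `M` and buffer name `b` are illustrative, not part of this change:

import itertools

import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # .t() returns a non-contiguous view, so this buffer violates the rule.
        self.register_buffer("b", torch.randn(3, 4).t())


# Walk parameters and buffers the same way the helper above does.
m = M()
for name, t in itertools.chain(m.named_parameters(), m.named_buffers()):
    if isinstance(t, torch.Tensor) and not t.is_contiguous():
        print(f"{name} is not contiguous")  # prints: b is not contiguous
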
@@ -97,15 +89,22 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:


 class EXIREdgeDialectVerifier(Verifier):
-    def __init__(self, check_edge_ops: bool = False) -> None:
+    def __init__(self, check_edge_ops: bool = True) -> None:
         self.check_edge_ops = check_edge_ops

-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
+        if self.check_edge_ops:
+            self.check_valid_op = self.check_valid_edge_op
+        else:
+            self.check_valid_op = self.check_valid_aten_op
+
+    def _check_attribute(self, mod: torch.fx.GraphModule, target: str) -> None:
+        # TODO: remove this once Executorch fully migrates to torch.export
+        pass

     def check_valid_edge_op(self, op):
+        if op is operator.getitem:
+            return
+
         if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
             raise SpecViolationError(
                 "Operator {}.{} is not an Edge operator.".format(
@@ -116,38 +115,28 @@ def check_valid_edge_op(self, op):
     def check_valid_aten_op(self, op) -> None:
         super().check_valid_op(op)

-        op_name = op.name if hasattr(op, "name") else op.__name__
-
-        if not isinstance(op, OpOverload):
-            raise SpecViolationError(
-                "Operator '{}' is not a registered Op".format(op_name),
-            )
-
-        if (
-            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
-            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
-        ):
-            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
-            # discussion.
-            raise SpecViolationError(
-                "Operator {}.{} is not Aten Canonical.".format(
-                    op.__module__, op.__name__
+        if isinstance(op, OpOverload):
+            if (
+                torch.Tag.core not in op.tags  # type: ignore[attr-defined]
+                and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
+            ):
+                # NOTE(qihan): whether view_copy operators are marked as canonical is still under
+                # discussion.
+                raise SpecViolationError(
+                    "Operator {}.{} is not Aten Canonical.".format(
+                        op.__module__, op.__name__
+                    )
                 )
-            )

-    def check_valid(self, gm: GraphModule) -> None:
+    def check_additional(self, gm: GraphModule) -> None:
         if self.check_edge_ops:
-            self.check_valid_op = self.check_valid_edge_op
-            super().check_valid(gm)
             _check_tensors_are_contiguous(gm)
             _check_tensor_args_matching_op_allowed_dtype(gm)
-        else:
-            self.check_valid_op = self.check_valid_aten_op

         # Additionally, edge dialect's operators must have the same input dtype
         for n in gm.graph.nodes:
             if n.op == "call_function" and isinstance(n.target, OpOverload):
-                _check_has_fake_tensor(n)
+                _check_val(n)
                 dtypes = set()
                 for arg in n.args:
                     if isinstance(arg, torch.Tensor):
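
The rename from `check_valid` to `check_additional` suggests the base `Verifier` now owns graph traversal and calls this method as a hook, which is why the manual `check_valid_op` dispatch moved into `__init__`. For reference, a sketch of the tag test that `check_valid_aten_op` now applies only to `OpOverload` targets; the sample op is our choice, not the diff's:

import torch

op = torch.ops.aten.add.Tensor  # an OpOverload from the core ATen opset
# The verifier accepts ops tagged core or view_copy; anything else raises.
print(torch.Tag.core in op.tags)       # expected: True
print(torch.Tag.view_copy in op.tags)  # expected: False
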