# LICENSE file in the root directory of this source tree.

from collections import defaultdict
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple

import torch
from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):

    def __init__(self, graph_module: torch.fx.GraphModule) -> None:
        super().__init__(graph_module)
-        self.violating_ops: Dict[EdgeOpOverload, Dict[str, torch.dtype]] = defaultdict(
-            dict
-        )
+        self.violating_ops: Dict[
+            EdgeOpOverload, Dict[str, Optional[torch.dtype]]
+        ] = defaultdict(dict)

    def run_node(self, n: torch.fx.Node) -> None:
        self.node = n
@@ -52,6 +52,16 @@ def run_node(self, n: torch.fx.Node) -> None:
            raise InternalError(str(e)) from e
        return ret

+    def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):
+        if schema_arg.name in kwargs:
+            kernel_arg = kwargs[schema_arg.name]
+        elif not schema_arg.kwarg_only and schema_arg_idx < len(args):
+            kernel_arg = args[schema_arg_idx]
+        else:
+            kernel_arg = schema_arg.default_value
+
+        return kernel_arg
+
    def call_function(
        self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
    ) -> Any:
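Note: the new _get_kernel_arg helper factors the keyword/positional/default lookup out of call_function so it can be reused for tensor, optional-tensor, and tensor-list arguments. Below is a minimal standalone sketch of that resolution order, using only the standard library; the _SchemaArg dataclass and resolve_kernel_arg name are illustrative stand-ins, not part of this change.

from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class _SchemaArg:
    # Illustrative stand-in for the schema-argument fields used in the diff.
    name: str
    kwarg_only: bool = False
    default_value: Any = None


def resolve_kernel_arg(
    schema_arg: _SchemaArg,
    schema_arg_idx: int,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    # Keyword argument wins, then a positional argument for non-kwarg-only
    # parameters, otherwise the schema default -- the same order as _get_kernel_arg.
    if schema_arg.name in kwargs:
        return kwargs[schema_arg.name]
    if not schema_arg.kwarg_only and schema_arg_idx < len(args):
        return args[schema_arg_idx]
    return schema_arg.default_value


# Usage: the second schema argument resolves positionally, the third from kwargs,
# and a missing argument falls back to its schema default.
args = (1, 2)
kwargs = {"gamma": 3}
print(resolve_kernel_arg(_SchemaArg("beta"), 1, args, kwargs))                    # 2
print(resolve_kernel_arg(_SchemaArg("gamma"), 2, args, kwargs))                   # 3
print(resolve_kernel_arg(_SchemaArg("delta", default_value=0), 3, args, kwargs))  # 0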
@@ -64,19 +74,32 @@ def call_function(
            if isinstance(target, HigherOrderOperator):
                raise RunHigherOrderOperatorError("Can't run delegate")
            return super().call_function(target, args, kwargs)
-        tensor_arg_types: Dict[str, torch.dtype] = {}
+
+        # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
+        tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
        for i, schema_arg in enumerate(target._schema.arguments):
-            if not isinstance(schema_arg.type, torch.TensorType):
-                continue
-            if schema_arg.name in kwargs:
-                kernel_arg = kwargs[schema_arg.name]
-            elif not schema_arg.kwarg_only and i < len(args):
-                kernel_arg = args[i]
-            else:
-                kernel_arg = schema_arg.default_value
-            if not isinstance(kernel_arg, torch.Tensor):
-                continue
-            tensor_arg_types[schema_arg.name] = kernel_arg.dtype
+            if (
+                isinstance(schema_arg.type, torch.TensorType)
+                or schema_arg.type == torch.OptionalType.ofTensor()
+            ):
+                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
+                if not isinstance(kernel_arg, torch.Tensor):
+                    continue
+                tensor_arg_types[schema_arg.name] = kernel_arg.dtype
+            elif schema_arg.type == torch.ListType.ofTensors():
+                kernel_arg = self._get_kernel_arg(schema_arg, i, args, kwargs)
+                if not isinstance(kernel_arg, list) or not all(
+                    isinstance(kernel_arg[i], torch.Tensor)
+                    for i in range(len(kernel_arg))
+                ):
+                    continue
+                if len(kernel_arg):
+                    tensor_arg_types[schema_arg.name] = kernel_arg[0].dtype
+                else:
+                    # If kernel_arg is an empty list, treat its type as None.
+                    # FunctionDtypeConstraint.validate will take None as any legal dtype.
+                    tensor_arg_types[schema_arg.name] = None
+
        ret_index = 0
        kernel_rets = self.node.meta["val"]
        ret_iter = iter(
@@ -85,11 +108,20 @@ def call_function(
        for schema_ret in target._schema.returns:
            name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
            kernel_ret = next(ret_iter)
+            # Return value should not be in OptionalTensor type, so only check torch.TensorType here.
            if isinstance(schema_ret.type, torch.TensorType) and isinstance(
                kernel_ret, torch.Tensor
            ):
                tensor_arg_types[name] = kernel_ret.dtype
                ret_index += 1
+            elif schema_ret.type == torch.ListType.ofTensors() and all(
+                isinstance(kernel_ret[i], torch.Tensor) for i in range(len(kernel_ret))
+            ):
+                if len(kernel_ret):
+                    tensor_arg_types[name] = kernel_ret[0].dtype
+                else:
+                    tensor_arg_types[name] = None
+                ret_index += 1

        valid = target._schema.dtype_constraint.validate(tensor_arg_types)
        if not valid:
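Note: with this change, Tensor[] arguments and returns enter the per-call dtype map as the dtype of their first element, or as None when the list is empty; FunctionDtypeConstraint.validate treats None as matching any legal dtype. Below is a small illustration of that mapping, assuming torch is installed; arg_dtype is an illustrative helper, not part of this change.

from typing import Dict, List, Optional, Union

import torch


def arg_dtype(value: Union[torch.Tensor, List[torch.Tensor]]) -> Optional[torch.dtype]:
    # A plain tensor contributes its own dtype; a tensor list contributes the
    # dtype of its first element, or None when the list is empty.
    if isinstance(value, torch.Tensor):
        return value.dtype
    return value[0].dtype if len(value) else None


tensor_arg_types: Dict[str, Optional[torch.dtype]] = {
    "self": arg_dtype(torch.ones(2, dtype=torch.float32)),
    "tensors": arg_dtype([torch.ones(2, dtype=torch.int64)]),
    "empty_list": arg_dtype([]),
}
print(tensor_arg_types)
# {'self': torch.float32, 'tensors': torch.int64, 'empty_list': None}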