@@ -29,7 +29,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
29
29
@staticmethod
30
30
def _get_args_kwargs (
31
31
inputs : Dict [str , List [Union [BaseArg ]]],
32
- dtypes : Tuple [torch .dtype ],
32
+ dtypes : Tuple [Optional [ torch .dtype ] ],
33
33
mode : ArgMode ,
34
34
) -> Tuple [List [BaseArg ], Dict [str , BaseKwarg ]]:
35
35
"""Construct args and kwargs for op given dtypes."""
@@ -50,8 +50,11 @@ def _get_args_kwargs(
50
50
51
51
@staticmethod
52
52
def _produce_dtype_tuple (
53
- types : List [ArgType ], code_tuple : Tuple [int ], ty : ArgType , dt : torch .dtype
54
- ) -> Optional [Tuple [torch .dtype ]]:
53
+ types : List [ArgType ],
54
+ code_tuple : Tuple [int ],
55
+ ty : ArgType ,
56
+ dt : Optional [torch .dtype ],
57
+ ) -> Optional [Tuple [Optional [torch .dtype ]]]:
55
58
dtype_tuple = []
56
59
for i , code in enumerate (code_tuple ):
57
60
same_group = [dt ]
@@ -84,7 +87,7 @@ def _produce_dtype_tuple(
84
87
85
88
def _get_type_tuples (
86
89
self , inputs : Dict [str , List [BaseArg ]]
87
- ) -> List [List [torch .dtype ]]:
90
+ ) -> List [List [Optional [ torch .dtype ] ]]:
88
91
types = DtypeRunner ._get_types (inputs )
89
92
90
93
def mapping (t ):
@@ -102,7 +105,7 @@ def mapping(t):
102
105
103
106
def select_dtype_combinations (
104
107
self , inputs : Dict [str , List [BaseArg ]], genmode : GenMode
105
- ) -> Iterator [Tuple [torch .dtype ]]:
108
+ ) -> Iterator [Tuple [Optional [ torch .dtype ] ]]:
106
109
random .seed (0 )
107
110
108
111
def produce_code_tuples (n : int , i : int ) -> Iterator [Tuple [int ]]:
@@ -134,9 +137,11 @@ def run_dtypes(
134
137
self ,
135
138
name : str ,
136
139
inputs : Dict [str , List [BaseArg ]],
137
- dtypes : Tuple [torch .dtype ],
140
+ dtypes : Tuple [Optional [ torch .dtype ] ],
138
141
argmode : ArgMode = ArgMode .RANDOM ,
139
- ) -> Tuple [bool , str , Tuple [torch .dtype ], List [BaseArg ], Dict [str , BaseKwarg ]]:
142
+ ) -> Tuple [
143
+ bool , str , Tuple [Optional [torch .dtype ]], List [BaseArg ], Dict [str , BaseKwarg ]
144
+ ]:
140
145
args , kwargs = DtypeRunner ._get_args_kwargs (inputs , dtypes , argmode )
141
146
op = get_callable (name )
142
147
try :
@@ -169,7 +174,9 @@ def run_dtypes(
169
174
def run (
170
175
self , name : str , inputs : Dict [str , List [BaseArg ]]
171
176
) -> List [
172
- Tuple [bool , str , Tuple [torch .dtype ], List [BaseArg ], Dict [str , BaseKwarg ]]
177
+ Tuple [
178
+ bool , str , Tuple [Optional [torch .dtype ]], List [BaseArg ], Dict [str , BaseKwarg ]
179
+ ]
173
180
]:
174
181
results = []
175
182
type_tuples = self ._get_type_tuples (inputs )
0 commit comments