Skip to content

Commit 2590257

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix type hints for optional dtype (#347)
Summary: Pull Request resolved: #347 Reviewed By: larryliu0820 Differential Revision: D49202933 fbshipit-source-id: 856903ab0543c9c8bf9c03eac7c6381a57566dd1
1 parent 72abf1c commit 2590257

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

exir/dialects/edge/dtype/runner.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _get_types(inputs: Dict[str, List[BaseArg]]) -> List[ArgType]:
2929
@staticmethod
3030
def _get_args_kwargs(
3131
inputs: Dict[str, List[Union[BaseArg]]],
32-
dtypes: Tuple[torch.dtype],
32+
dtypes: Tuple[Optional[torch.dtype]],
3333
mode: ArgMode,
3434
) -> Tuple[List[BaseArg], Dict[str, BaseKwarg]]:
3535
"""Construct args and kwargs for op given dtypes."""
@@ -50,8 +50,11 @@ def _get_args_kwargs(
5050

5151
@staticmethod
5252
def _produce_dtype_tuple(
53-
types: List[ArgType], code_tuple: Tuple[int], ty: ArgType, dt: torch.dtype
54-
) -> Optional[Tuple[torch.dtype]]:
53+
types: List[ArgType],
54+
code_tuple: Tuple[int],
55+
ty: ArgType,
56+
dt: Optional[torch.dtype],
57+
) -> Optional[Tuple[Optional[torch.dtype]]]:
5558
dtype_tuple = []
5659
for i, code in enumerate(code_tuple):
5760
same_group = [dt]
@@ -84,7 +87,7 @@ def _produce_dtype_tuple(
8487

8588
def _get_type_tuples(
8689
self, inputs: Dict[str, List[BaseArg]]
87-
) -> List[List[torch.dtype]]:
90+
) -> List[List[Optional[torch.dtype]]]:
8891
types = DtypeRunner._get_types(inputs)
8992

9093
def mapping(t):
@@ -102,7 +105,7 @@ def mapping(t):
102105

103106
def select_dtype_combinations(
104107
self, inputs: Dict[str, List[BaseArg]], genmode: GenMode
105-
) -> Iterator[Tuple[torch.dtype]]:
108+
) -> Iterator[Tuple[Optional[torch.dtype]]]:
106109
random.seed(0)
107110

108111
def produce_code_tuples(n: int, i: int) -> Iterator[Tuple[int]]:
@@ -134,9 +137,11 @@ def run_dtypes(
134137
self,
135138
name: str,
136139
inputs: Dict[str, List[BaseArg]],
137-
dtypes: Tuple[torch.dtype],
140+
dtypes: Tuple[Optional[torch.dtype]],
138141
argmode: ArgMode = ArgMode.RANDOM,
139-
) -> Tuple[bool, str, Tuple[torch.dtype], List[BaseArg], Dict[str, BaseKwarg]]:
142+
) -> Tuple[
143+
bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg]
144+
]:
140145
args, kwargs = DtypeRunner._get_args_kwargs(inputs, dtypes, argmode)
141146
op = get_callable(name)
142147
try:
@@ -169,7 +174,9 @@ def run_dtypes(
169174
def run(
170175
self, name: str, inputs: Dict[str, List[BaseArg]]
171176
) -> List[
172-
Tuple[bool, str, Tuple[torch.dtype], List[BaseArg], Dict[str, BaseKwarg]]
177+
Tuple[
178+
bool, str, Tuple[Optional[torch.dtype]], List[BaseArg], Dict[str, BaseKwarg]
179+
]
173180
]:
174181
results = []
175182
type_tuples = self._get_type_tuples(inputs)

0 commit comments

Comments
 (0)