Skip to content

Commit 2587446

Browse files
mergennachinfacebook-github-bot
authored andcommitted
Simplify check_node_has_valid_dtype (#1801)
Summary: Pull Request resolved: #1801 Linter seems to be unhappy. https://github.com/pytorch/executorch/actions/runs/7736952506/job/21095126521 Reviewed By: mcr229 Differential Revision: D53313246 fbshipit-source-id: 1aef8d38a6816db8292ec25e0c1bce5db0a545ef
1 parent 7ea3ebd commit 2587446

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,7 @@ def __init__(
100100
self.ep = ep
101101
assert len(self.constraints)
102102

103-
def check_node_has_valid_dtype(self, node):
104-
if node.target in {exir_ops.edge.aten.max_pool2d_with_indices.default}:
105-
return True
106-
107-
valid_dtypes = {
108-
torch.float32,
109-
torch.int8,
110-
torch.qint8,
111-
}
112-
if (
113-
node.op != "placeholder"
114-
and node.op != "call_function"
115-
and node.op != "get_attr"
116-
):
117-
return False
118-
103+
def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
119104
# Check inputs are valid dtypes
120105
for arg in node.args:
121106
if not isinstance(arg, torch.fx.Node):
@@ -133,21 +118,46 @@ def check_node_has_valid_dtype(self, node):
133118
if arg_val.dtype not in valid_dtypes:
134119
return False
135120

121+
return True
122+
123+
def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
136124
# Check outputs are valid dtype
137125
node_val = node.meta.get("val", None)
138-
if node_val is not None:
139-
if not isinstance(node_val, tuple):
140-
node_val = (node_val,)
126+
if node_val is None:
127+
return True
141128

142-
for val in node_val:
143-
if not isinstance(val, torch.Tensor):
144-
return False
129+
if not isinstance(node_val, tuple):
130+
node_val = (node_val,)
145131

146-
if val.dtype not in valid_dtypes:
147-
return False
132+
for val in node_val:
133+
if not isinstance(val, torch.Tensor):
134+
return False
135+
136+
if val.dtype not in valid_dtypes:
137+
return False
148138

149139
return True
150140

141+
def check_node_has_valid_dtype(self, node):
142+
if node.target in {exir_ops.edge.aten.max_pool2d_with_indices.default}:
143+
return True
144+
145+
valid_dtypes = {
146+
torch.float32,
147+
torch.int8,
148+
torch.qint8,
149+
}
150+
if (
151+
node.op != "placeholder"
152+
and node.op != "call_function"
153+
and node.op != "get_attr"
154+
):
155+
return False
156+
157+
return self._check_inputs_are_valid_dtypes(
158+
node, valid_dtypes
159+
) and self._check_outputs_are_valid_dtypes(node, valid_dtypes)
160+
151161
def check_common_constraints(self, node) -> bool:
152162
has_valid_dtypes = self.check_node_has_valid_dtype(node)
153163

0 commit comments

Comments
 (0)