@@ -100,22 +100,7 @@ def __init__(
100
100
self .ep = ep
101
101
assert len (self .constraints )
102
102
103
- def check_node_has_valid_dtype (self , node ):
104
- if node .target in {exir_ops .edge .aten .max_pool2d_with_indices .default }:
105
- return True
106
-
107
- valid_dtypes = {
108
- torch .float32 ,
109
- torch .int8 ,
110
- torch .qint8 ,
111
- }
112
- if (
113
- node .op != "placeholder"
114
- and node .op != "call_function"
115
- and node .op != "get_attr"
116
- ):
117
- return False
118
-
103
+ def _check_inputs_are_valid_dtypes (self , node , valid_dtypes ):
119
104
# Check inputs are valid dtypes
120
105
for arg in node .args :
121
106
if not isinstance (arg , torch .fx .Node ):
@@ -133,21 +118,46 @@ def check_node_has_valid_dtype(self, node):
133
118
if arg_val .dtype not in valid_dtypes :
134
119
return False
135
120
121
+ return True
122
+
123
+ def _check_outputs_are_valid_dtypes (self , node , valid_dtypes ):
136
124
# Check outputs are valid dtype
137
125
node_val = node .meta .get ("val" , None )
138
- if node_val is not None :
139
- if not isinstance (node_val , tuple ):
140
- node_val = (node_val ,)
126
+ if node_val is None :
127
+ return True
141
128
142
- for val in node_val :
143
- if not isinstance (val , torch .Tensor ):
144
- return False
129
+ if not isinstance (node_val , tuple ):
130
+ node_val = (node_val ,)
145
131
146
- if val .dtype not in valid_dtypes :
147
- return False
132
+ for val in node_val :
133
+ if not isinstance (val , torch .Tensor ):
134
+ return False
135
+
136
+ if val .dtype not in valid_dtypes :
137
+ return False
148
138
149
139
return True
150
140
141
+ def check_node_has_valid_dtype (self , node ):
142
+ if node .target in {exir_ops .edge .aten .max_pool2d_with_indices .default }:
143
+ return True
144
+
145
+ valid_dtypes = {
146
+ torch .float32 ,
147
+ torch .int8 ,
148
+ torch .qint8 ,
149
+ }
150
+ if (
151
+ node .op != "placeholder"
152
+ and node .op != "call_function"
153
+ and node .op != "get_attr"
154
+ ):
155
+ return False
156
+
157
+ return self ._check_inputs_are_valid_dtypes (
158
+ node , valid_dtypes
159
+ ) and self ._check_outputs_are_valid_dtypes (node , valid_dtypes )
160
+
151
161
def check_common_constraints (self , node ) -> bool :
152
162
has_valid_dtypes = self .check_node_has_valid_dtype (node )
153
163
0 commit comments