
# pyre-unsafe

+import itertools
import operator
+import typing
from typing import final, Optional, Sequence, Type

+import torch
+
import torch.fx as fx
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
+    FuseQuantizedActivationPass,
+)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class SupportedTOSAOperatorCheck(OperatorSupportBase):
@@ -27,7 +36,9 @@ def __init__(self, tosa_spec: TosaSpecification):
    targets: list[str] = []

    @final
-    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
        if node.target not in self.targets:
            return False
        return self.is_node_tosa_supported(node, self.tosa_spec)
@@ -75,6 +86,10 @@ def tosa_support_factory(
    tosa_spec: TosaSpecification,
    additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
) -> OperatorSupportBase:
+    negative_checks: list[OperatorSupportBase] = []
+    if not tosa_spec.support_float():
+        negative_checks.append(NeedsDecompositionCheck())
+        negative_checks.append(CheckProperQuantization())
    return chain(
        any_chain(
            BaseTOSASupportList(),
@@ -83,13 +98,16 @@ def tosa_support_factory(
                for check in get_registered_tosa_support_checks(tosa_spec)
            ),
        ),
+        *negative_checks,
        *additional_checks if additional_checks else [],
    )


class BaseTOSASupportList(OperatorSupportBase):

-    def is_node_supported(self, submodules, node: fx.Node) -> bool:
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
        supported = node.op == "call_function" and node.target in [
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.add.Tensor,
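
(Note, not part of the diff) The factory composes checks with `chain` and `any_chain` from `torch.fx.passes.operator_support`: `chain` accepts a node only if every check accepts it, while `any_chain` accepts it if at least one does, so the negative checks added above can veto any node the support lists allowed. A minimal sketch of that composition with made-up checks (the `AlwaysYes`/`AlwaysNo` classes are illustrative only):

```python
# Sketch only: illustrates chain/any_chain semantics with toy checks.
import torch
import torch.fx as fx
from torch.fx.passes.operator_support import OperatorSupportBase, any_chain, chain


class AlwaysYes(OperatorSupportBase):
    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        return True


class AlwaysNo(OperatorSupportBase):
    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        return False


def f(x):
    return torch.add(x, x)


graph_module = fx.symbolic_trace(f)
node = next(n for n in graph_module.graph.nodes if n.op == "call_function")

print(chain(AlwaysYes(), AlwaysNo()).is_node_supported({}, node))      # False: all checks must accept
print(any_chain(AlwaysYes(), AlwaysNo()).is_node_supported({}, node))  # True: one acceptance is enough
```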
@@ -150,3 +168,154 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
        ]

        return supported
+
+
+class NeedsDecompositionCheck(OperatorSupportBase):
+    """
+    Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding
+    the operator, and to get optimal quantization parameters for each operator. This check will reject operators
+    that need to be decomposed.
+    """
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+
+        if node.op != "call_function":
+            return True
+        if node.target == exir_ops.edge.aten.mean.dim:
+            dim = node.args[1]
+            return dim == [-1, -2]
+        needs_decomp = node.target in [
+            exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+            exir_ops.edge.aten.native_layer_norm.default,
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten._softmax.default,
+            exir_ops.edge.aten._log_softmax.default,
+            exir_ops.edge.aten.var.correction,
+            exir_ops.edge.aten.var.dim,
+        ]
+        return not needs_decomp
+
+
+class CheckProperQuantization(OperatorSupportBase):
+    """
+    For targeted nodes, check that they have been quantized as expected. In most cases this means that a pair of
+    quantize and dequantize nodes surrounds the node. This is necessary for table operators and operators that need
+    to rescale activations.
+    """
+
+    dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+    q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+
+    def _is_matmul_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ):
+        """
+        Find the matmul source partition containing this node and check that all its inputs and outputs are quantized.
+        """
+        for graph_module in submodules.values():
+            graph_module = typing.cast(fx.GraphModule, graph_module)
+            matmul_partitions = get_source_partitions(
+                graph_module.graph,
+                [
+                    torch.matmul,
+                ],
+                None,
+            )
+            matmul_partitions = list(
+                itertools.chain.from_iterable(matmul_partitions.values())
+            )
+            matched_partition = None
+            for partition in matmul_partitions:
+                if node in partition.nodes:
+                    matched_partition = partition
+            if matched_partition is not None:
+                input_quantized = all(
+                    input_node.target == self.dq_op
+                    for input_node in matched_partition.input_nodes
+                )
+                if not input_quantized:
+                    return False
+                output_quantized = all(
+                    output_node_user.target == self.q_op
+                    for output_node_user in matched_partition.output_nodes[0].users
+                )
+                if not output_quantized:
+                    return False
+            else:
+                return False
+
+        return True
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        output_quantized = False
+        input_quantized = False
+        if node.target not in (
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.bmm.default,
+            exir_ops.edge.aten.convolution.default,
+            exir_ops.edge.aten.exp.default,
+            exir_ops.edge.aten.hardtanh.default,
+            exir_ops.edge.aten.linear.default,
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.max_pool2d_with_indices.default,
+            exir_ops.edge.aten.mm.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.reciprocal.default,
+            exir_ops.edge.aten.relu.default,
+            exir_ops.edge.aten.rsqrt.default,
+            exir_ops.edge.aten.sigmoid.default,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.tanh.default,
+            exir_ops.edge.aten.upsample_nearest2d.vec,
+        ):
+            return True
+        elif node.target in (
+            exir_ops.edge.aten.bmm.default,
+            exir_ops.edge.aten.mm.default,
+        ):
+            source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", [])
+            if len(source_fn_stack) > 0:
+                if source_fn_stack[-1][1] in (torch.matmul,):
+                    return self._is_matmul_node_supported(submodules, node)
+
+        elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):
+            users = node.users
+            output_quantized = all(
+                user.target == operator.getitem
+                and all(user_user.target == self.q_op for user_user in user.users)
+                for user in users
+            )
+        elif FuseQuantizedActivationPass._is_fuseable_input(node):
+            users = node.users
+            output_quantized = all(
+                FuseQuantizedActivationPass._is_fuseable_quantized_activation(user)
+                for user in users
+            )
+        elif FuseQuantizedActivationPass._is_fuseable_quantized_activation(node):
+            input_node = node.all_input_nodes[0]
+            input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)
+
+        input_quantized = input_quantized or all(
+            (input_node.target == self.dq_op)
+            or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
+            for input_node in node.all_input_nodes
+        )
+
+        if not input_quantized:
+            return False
+
+        output_quantized = output_quantized or all(
+            (output_node.target == self.q_op)
+            or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
+            for output_node in node.users
+        )
+
+        if not output_quantized:
+            return False
+        return True
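
(Note, not part of the diff) The object returned by `tosa_support_factory` is a plain `OperatorSupportBase`, so it can drive any capability-based partitioning step; on integer-only TOSA profiles the two new negative checks veto nodes that should have been decomposed before quantization or that are not surrounded by the expected dequantize/quantize pairs. A rough usage sketch follows; the import path, the `TosaSpecification.create_from_string` call, the "TOSA-0.80+BI" string, and the example module are assumptions for illustration, not taken from this commit:

```python
# Sketch only: feed the factory's support object to torch.fx's capability-based
# partitioner. A plain traced module is used here for brevity; a real flow would
# run this on the quantized edge-dialect graph.
import torch
import torch.fx as fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from executorch.backends.arm.operator_support.tosa_supported_operators import (  # assumed path
    tosa_support_factory,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")  # assumed spec string
operator_support = tosa_support_factory(tosa_spec)

graph_module = fx.symbolic_trace(
    torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
)
partitioner = CapabilityBasedPartitioner(
    graph_module,
    operator_support,
    allows_single_node_partition=True,
)
partitions = partitioner.propose_partitions()  # groups of nodes the checks accepted
```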