 from __future__ import annotations
 
-import copy
 import functools
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    mark_nodes_as_annotated,
-    propagate_annotation,
-)
-from executorch.backends.arm.quantizer.quantization_annotation import (
-    OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
-)
+from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.fake_quantize import (
     FakeQuantize,
@@ -58,44 +50,6 @@
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-    supported_operators: Dict[str, List[OperatorPatternType]] = {
-        # Both conv and linear should be able to handle relu + hardtanh fusion since
-        # those are clamp ops
-        "conv2d": [
-            [torch.nn.Conv2d, torch.nn.ReLU],
-            [torch.nn.Conv2d, F.relu],
-            [F.conv2d, torch.nn.ReLU],
-            [F.conv2d, F.relu],
-        ],
-        "linear": [[torch.nn.Linear], [F.linear]],
-        "add": [[torch.add]],
-        "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-        "adaptive_avg_pool2d": [
-            [torch.nn.AdaptiveAvgPool2d],
-            [F.adaptive_avg_pool2d],
-        ],
-        "mul": [[torch.mul]],
-        "sub": [[torch.sub]],
-        "min_max": [[torch.min], [torch.max]],
-    }
-    return copy.deepcopy(supported_operators)
-
-
-def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-    supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [
-        get_symmetric_quantization_config(),
-        get_symmetric_quantization_config(is_per_channel=True),
-    ]:
-        ops = _supported_symmetric_quantized_operators()
-        for pattern_list in ops.values():
-            supported_config_and_operators.append(
-                OperatorConfig(quantization_config, pattern_list)
-            )
-    return copy.deepcopy(supported_config_and_operators)
-
-
 @functools.lru_cache
 def get_symmetric_quantization_config(
     is_per_channel: bool = False,
@@ -180,10 +134,6 @@ def get_symmetric_quantization_config(
     return quantization_config
 
 
-def _get_supported_config_and_operators() -> List[OperatorConfig]:
-    return _get_supported_symmetric_config_and_operators()
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.
@@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 
 
 class ArmQuantizer(Quantizer):
-    supported_config_and_operators = _get_supported_config_and_operators()
-
-    # A list of supported static quantization annotators, in order of application.
-    # For example, fusions come before singular ops.
-    # The name must match the name used when registering the annotator.
-    STATIC_ANNOTATION_ORDER = [
-        "linear",
-        "conv",
-        "adaptive_avg_pool2d",
-        "max_pool2d",
-        "add",
-        "sub",
-        "mul",
-        "min_max",
-        "mm",
-        "one_to_one",
-        "generic",
-        "upsample_nearest2d",
-    ]
-
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
@@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
             The annotated model.
         """
         model = self._annotate_for_static_quantization_config(model)
-        propagate_annotation(model)
        return model
 
     def _annotate_all_static_patterns(
@@ -353,8 +282,7 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
             return model
 
-        for op in self.STATIC_ANNOTATION_ORDER:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+        annotate_graph(model, quantization_config, filter_fn)
         return model
 
     def _annotate_for_static_quantization_config(
@@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config(
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
+        if self.io_config:
+            self._annotate_io(model, self.io_config)
+
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
@@ -381,9 +312,6 @@ def _annotate_for_static_quantization_config(
                 _get_not_module_type_or_name_filter(tp_list, module_name_list),
             )
 
-        if self.io_config:
-            self._annotate_io(model, self.io_config)
-
         return model
 
     def _annotate_io(
@@ -399,44 +327,13 @@ def _annotate_io(
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
                 _annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
-                mark_nodes_as_annotated([node])
+                mark_node_as_annotated(node)
 
     def validate(self, model: GraphModule) -> None:
         pass
-
-    @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
-        return cls.supported_config_and_operators
-
-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
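
In short, the diff collapses the per-operator OP_TO_ANNOTATOR loop into a single annotate_graph pass, drops propagate_annotation and the OperatorConfig/OperatorPatternType reporting classmethods, and annotates model IO before the name/type-filtered patterns instead of after. The quantizer's public entry points are unchanged, so existing callers should keep working. Below is a minimal usage sketch for reference, not taken from this commit: it assumes ArmQuantizer's existing set_global setter and the stock PT2E prepare/convert helpers, and the graph-capture call in particular varies between PyTorch/ExecuTorch versions.

# Hypothetical end-to-end sketch; adjust the capture step to your install.
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 16),)

# Capture the model as a GraphModule. Depending on the installed version the
# ExecuTorch examples use capture_pre_autograd_graph or export_for_training here.
graph_module = torch.export.export(model, example_inputs).module()

quantizer = ArmQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

prepared = prepare_pt2e(graph_module, quantizer)  # drives ArmQuantizer.annotate()
prepared(*example_inputs)                         # calibrate with sample data
quantized = convert_pt2e(prepared)                # insert quantize/dequantize ops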