@@ -13,24 +13,16 @@

from __future__ import annotations

- import copy
import functools

- from typing import Any, Callable, Dict, List, Optional, Set
+ from typing import Any, Callable, Dict, List, Optional

import torch
- import torch.nn.functional as F
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
- from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-     mark_nodes_as_annotated,
-     propagate_annotation,
- )
- from executorch.backends.arm.quantizer.quantization_annotation import (
-     OP_TO_ANNOTATOR,
-     OperatorConfig,
-     OperatorPatternType,
- )
+ from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
+ from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
@@ -58,43 +50,6 @@
]


- def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
-     supported_operators: Dict[str, List[OperatorPatternType]] = {
-         # Both conv and linear should be able to handle relu + hardtanh fusion since
-         # those are clamp ops
-         "conv2d": [
-             [torch.nn.Conv2d, torch.nn.ReLU],
-             [torch.nn.Conv2d, F.relu],
-             [F.conv2d, torch.nn.ReLU],
-             [F.conv2d, F.relu],
-         ],
-         "linear": [[torch.nn.Linear], [F.linear]],
-         "add": [[torch.add]],
-         "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
-         "adaptive_avg_pool2d": [
-             [torch.nn.AdaptiveAvgPool2d],
-             [F.adaptive_avg_pool2d],
-         ],
-         "mul": [[torch.mul]],
-         "sub": [[torch.sub]],
-     }
-     return copy.deepcopy(supported_operators)
-
-
- def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
-     supported_config_and_operators: List[OperatorConfig] = []
-     for quantization_config in [
-         get_symmetric_quantization_config(),
-         get_symmetric_quantization_config(is_per_channel=True),
-     ]:
-         ops = _supported_symmetric_quantized_operators()
-         for pattern_list in ops.values():
-             supported_config_and_operators.append(
-                 OperatorConfig(quantization_config, pattern_list)
-             )
-     return copy.deepcopy(supported_config_and_operators)
-
-
@functools.lru_cache
def get_symmetric_quantization_config(
    is_per_channel: bool = False,
@@ -179,10 +134,6 @@ def get_symmetric_quantization_config(
    return quantization_config


- def _get_supported_config_and_operators() -> List[OperatorConfig]:
-     return _get_supported_symmetric_config_and_operators()
-
-
NodeFilterType = Callable[[Node], bool]
"""Type for a Node Filter used by annotators. A Node filter is a function that takes
a Node and returns whether the node should be annotated or not.
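For illustration, a node filter in this sense is just a predicate over FX nodes. A minimal hypothetical example (the filter shown here is not from this file; the real filters used below are built by the `_get_module_name_filter`-style helpers elsewhere in the module):

    from torch.fx import Node

    def is_conv_node(n: Node) -> bool:
        # Annotate only aten convolution calls in the exported graph.
        return n.op == "call_function" and "convolution" in str(n.target)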
@@ -254,25 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool:


class ArmQuantizer(Quantizer):
-     supported_config_and_operators = _get_supported_config_and_operators()
-
-     # A list of supported static quantization annotators, in order of application.
-     # For example, fusions come before singular ops.
-     # The name must match the name used when registering the annotator.
-     STATIC_ANNOTATION_ORDER = [
-         "linear",
-         "conv",
-         "adaptive_avg_pool2d",
-         "max_pool2d",
-         "add",
-         "sub",
-         "mul",
-         "mm",
-         "one_to_one",
-         "generic",
-         "upsample_nearest2d",
-     ]
-
    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None
@@ -329,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule:
            The annotated model.
        """
        model = self._annotate_for_static_quantization_config(model)
-         propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
@@ -351,8 +282,7 @@ def _annotate_all_static_patterns(
        if quantization_config is None:
            return model

-         for op in self.STATIC_ANNOTATION_ORDER:
-             OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+         annotate_graph(model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
@@ -361,6 +291,9 @@ def _annotate_for_static_quantization_config(
        """Matches the correct QuantizationConfig with the correct module using a filter
        when running _annotate_all_static_patterns.
        """
+         if self.io_config:
+             self._annotate_io(model, self.io_config)
+
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_static_patterns(
@@ -379,9 +312,6 @@ def _annotate_for_static_quantization_config(
                _get_not_module_type_or_name_filter(tp_list, module_name_list),
            )

-         if self.io_config:
-             self._annotate_io(model, self.io_config)
-
        return model

    def _annotate_io(
@@ -397,44 +327,13 @@ def _annotate_io(
                    node,
                    quantization_config.get_output_act_qspec(),
                )
-                 mark_nodes_as_annotated([node])
+                 mark_node_as_annotated(node)
            if node.op == "output":
                parent = node.all_input_nodes[0]
                _annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
-                 mark_nodes_as_annotated([node])
+                 mark_node_as_annotated(node)

    def validate(self, model: GraphModule) -> None:
        pass
-
-     @classmethod
-     def get_supported_operators(cls) -> List[OperatorConfig]:
-         return cls.supported_config_and_operators
-
-     @classmethod
-     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-         op_configs: Set[QuantizationConfig] = set({})
-         for spec, _ in cls.supported_config_and_operators:
-             op_configs.add(spec)
-         return list(op_configs)
-
-     @classmethod
-     def get_supported_operator_for_quantization_config(
-         cls, quantization_config: Optional[QuantizationConfig]
-     ) -> List[OperatorPatternType]:
-         if quantization_config is None:
-             all_ops = []
-             for _, ops in cls.supported_config_and_operators:
-                 all_ops.extend(ops)
-             return all_ops
-
-         for config, ops in cls.supported_config_and_operators:
-             # note: this assumes each entry in cls.supported_spec_and_operators
-             # corresponds to one spec, e.g. we don't have
-             # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-             # where the first and second entry have the same spec but did not
-             # merge the op list
-             if config == quantization_config:
-                 return ops
-         return []
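For reviewers trying the change end to end, a minimal usage sketch. It assumes the PT2E flow from the same era as this diff (`capture_pre_autograd_graph`, `prepare_pt2e`, `convert_pt2e`) and that `set_global` keeps its pre-refactor behavior, since this change only touches the annotation internals; the toy model and shapes are illustrative only:

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    from executorch.backends.arm.quantizer.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )

    # Illustrative toy model; any module the Arm backend supports would do.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)

    # Capture to an FX GraphModule, as the PT2E flow expects.
    graph_module = capture_pre_autograd_graph(model, example_inputs)

    # One global symmetric config; annotate() now delegates to annotate_graph()
    # instead of looping over the removed STATIC_ANNOTATION_ORDER.
    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*example_inputs)  # calibration pass with representative data
    quantized = convert_pt2e(prepared)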