
Commit b16271c

Refactor CadenceQuantizer
Differential Revision: D67645196
Pull Request resolved: #7540
1 parent 25a94ef commit b16271c

File tree

backends/cadence/aot/compiler.py
backends/cadence/aot/export_example.py
backends/cadence/aot/quantizer/quantizer.py
backends/cadence/aot/tests/test_remove_ops_passes.py

4 files changed, +50 -32 lines changed

backends/cadence/aot/compiler.py

Lines changed: 5 additions & 2 deletions
@@ -17,7 +17,10 @@
     print_memory_planning_info,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceDefaultQuantizer,
+    CadenceQuantizer,
+)
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
     MemoryConfig,

@@ -136,7 +139,7 @@ def quantize_pt2(

     # Instantiate the quantizer to CadenceQuantizer if not supplied
     if not quantizer:
-        quantizer = CadenceQuantizer()
+        quantizer = CadenceDefaultQuantizer()

     # Get converted graph module
     converted_gm = convert_pt2(model, inputs, quantizer)

backends/cadence/aot/export_example.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
     fuse_pt2,
 )

-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
 from executorch.exir import ExecutorchProgramManager

@@ -74,7 +74,7 @@ def export_model(
     )

     # Instantiate the quantizer
-    quantizer = CadenceQuantizer(qconfig)
+    quantizer = CadenceDefaultQuantizer(qconfig)

     # Convert the model
     converted_model = convert_pt2(model, example_inputs, quantizer)

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 41 additions & 26 deletions
@@ -60,6 +60,13 @@

 bias_qspec: Optional[QuantizationSpec] = None

+_default_qconfig = QuantizationConfig(
+    act_qspec,
+    act_qspec,
+    wgt_qspec,
+    None,
+)
+

 class CadenceAtenQuantizer(Quantizer):
     def __init__(

@@ -140,31 +147,39 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []


+def get_cadence_default_quantizer_list_with_config(
+    quantization_config: QuantizationConfig,
+) -> List[Quantizer]:
+    return [
+        CadenceAtenQuantizer(AddmmPattern(), quantization_config),
+        CadenceAtenQuantizer(BmmPattern(), quantization_config),
+        CadenceAtenQuantizer(Conv1dPattern(), quantization_config),
+        CadenceAtenQuantizer(Conv2dPattern(), quantization_config),
+        CadenceAtenQuantizer(LayerNormPattern(), quantization_config),
+        CadenceAtenQuantizer(LinearPattern(), quantization_config),
+        CadenceAtenQuantizer(MatmulPattern(), quantization_config),
+        CadenceAtenQuantizer(ReluPattern0(), quantization_config),
+        CadenceAtenQuantizer(ReluPattern1(), quantization_config),
+    ]
+
+
 class CadenceQuantizer(ComposableQuantizer):
-    def __init__(
-        self, quantization_config: Optional[QuantizationConfig] = None
-    ) -> None:
-        static_qconfig = (
-            QuantizationConfig(
-                act_qspec,
-                act_qspec,
-                wgt_qspec,
-                None,
-            )
-            if not quantization_config
-            else quantization_config
-        )
+    """
+    Generic CadenceQuantizer. Although it can be used directly, it is typically a base
+    class for explicitly defined quantizers (like CadenceDefaultQuantizer).
+    """

-        super().__init__(
-            [
-                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
-                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
-                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
-                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
-                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
-                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
-                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
-                CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
-                CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
-            ]
-        )
+    def __init__(self, quantizers: List[Quantizer]) -> None:
+        super().__init__(quantizers)
+
+
+class CadenceDefaultQuantizer(CadenceQuantizer):
+    """
+    Default quantizer for Cadence backend.
+    """
+
+    def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
+        if qconfig is None:
+            qconfig = _default_qconfig
+        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
+        super().__init__(quantizers)
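After this change, CadenceQuantizer is a thin wrapper over ComposableQuantizer that takes an explicit list of quantizers, while CadenceDefaultQuantizer reproduces the previous default behavior. The sketch below (not part of this commit, using only names visible in the diff above) illustrates how the two constructors relate:

    # Minimal sketch, assuming the module path shown in the imports above.
    from executorch.backends.cadence.aot.quantizer.quantizer import (
        CadenceDefaultQuantizer,
        CadenceQuantizer,
        get_cadence_default_quantizer_list_with_config,
        _default_qconfig,  # module-level internal default added in this diff
    )

    # Stock behavior: all default patterns quantized with the default qconfig.
    default_quantizer = CadenceDefaultQuantizer()

    # Equivalent construction through the base class, which now simply
    # composes an explicit list of pattern quantizers.
    equivalent_quantizer = CadenceQuantizer(
        get_cadence_default_quantizer_list_with_config(_default_qconfig)
    )

A custom CadenceQuantizer can likewise be built by passing a hand-picked subset of CadenceAtenQuantizer instances instead of the full default list.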

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from executorch.backends.cadence.aot.compiler import export_to_edge

 from executorch.backends.cadence.aot.pass_utils import count_node
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.aot.remove_ops import (
     RemoveAliasCopyOpPass,
     RemoveCloneOpPass,

@@ -465,7 +465,7 @@ def forward(self, x):

         # Run the standard quant/convert steps, but without fusing
         # this leaves two redundant quant/dequant pairs to test with
-        quantizer = CadenceQuantizer()
+        quantizer = CadenceDefaultQuantizer()
         model_exp = export_for_training(M(), (inp,)).module()
         prepared_model = prepare_pt2e(model_exp, quantizer)
         prepared_model(inp)
