Commit 3681588

Use symmetric weights for convs and int8 in the default quantizer
Differential Revision: D69405797
Pull Request resolved: #8344
1 parent 95ef21d commit 3681588

4 files changed (+48 -94 lines)


backends/cadence/aot/export_example.py

Lines changed: 3 additions & 33 deletions
@@ -6,11 +6,11 @@
 
 # Example script for exporting simple models to flatbuffer
 
+#pyre-unsafe
+
 import logging
 import tempfile
 
-import torch
-
 from executorch.backends.cadence.aot.ops_registrations import *  # noqa
 from typing import Any, Tuple
 
@@ -23,38 +23,15 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
-from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
-    QuantizationConfig,
-    QuantizationSpec,
-)
 from executorch.exir import ExecutorchProgramManager
 from torch import nn
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
 
 from .utils import save_bpte_program, save_pte_program
 
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
-
 
 def export_model(
     model: nn.Module,
@@ -66,15 +43,8 @@ def export_model(
     working_dir = tempfile.mkdtemp(dir="/tmp")
     logging.debug(f"Created work directory {working_dir}")
 
-    qconfig = QuantizationConfig(
-        act_qspec,
-        act_qspec,
-        wgt_qspec,
-        None,
-    )
-
     # Instantiate the quantizer
-    quantizer = CadenceDefaultQuantizer(qconfig)
+    quantizer = CadenceDefaultQuantizer()
 
     # Convert the model
     converted_model = convert_pt2(model, example_inputs, quantizer)
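As a usage note, here is a minimal sketch of what the export example does after this change. The model, inputs, and the convert_pt2 import path are illustrative assumptions rather than part of the diff:

import torch
from torch import nn

# Assumed import path for convert_pt2; the example file imports it elsewhere.
from executorch.backends.cadence.aot.compiler import convert_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = nn.Linear(16, 8)                # placeholder model
example_inputs = (torch.randn(1, 16),)  # placeholder inputs

# No local QuantizationConfig is built anymore; the default quantizer now owns
# its configuration (asymmetric int8 activations, symmetric int8 conv weights).
quantizer = CadenceDefaultQuantizer()
converted_model = convert_pt2(model, example_inputs, quantizer)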

backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 1 deletion
@@ -576,7 +576,7 @@ def quantized_relu_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
-    return input.new_empty(input.size(), dtype=torch.uint8)
+    return input.new_empty(input.size(), dtype=input.dtype)
 
 
 @register_fake("cadence::fully_connected")
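A small sketch of what this meta-kernel fix changes during tracing; the tensor values are illustrative only:

import torch

# Before: the fake kernel always returned uint8, mismatching int8 graphs.
# After: it mirrors the input dtype, so an int8 quantized_relu traces as int8.
x = torch.zeros(4, dtype=torch.int8)
out = x.new_empty(x.size(), dtype=x.dtype)
assert out.dtype == torch.int8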

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 41 additions & 28 deletions
@@ -40,30 +40,46 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 
 
-act_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+act_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
-wgt_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+wgt_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
+wgt_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
 bias_qspec: Optional[QuantizationSpec] = None
 
-_default_qconfig = QuantizationConfig(
-    act_qspec,
-    act_qspec,
-    wgt_qspec,
+qconfig_A8uW8u = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8u,
+    None,
+)
+
+qconfig_A8uW8s = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8s,
     None,
 )
 
@@ -147,19 +163,17 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []
 
 
-def get_cadence_default_quantizer_list_with_config(
-    quantization_config: QuantizationConfig,
-) -> List[Quantizer]:
+def get_cadence_default_quantizers() -> List[Quantizer]:
     return [
-        CadenceAtenQuantizer(AddmmPattern(), quantization_config),
-        CadenceAtenQuantizer(BmmPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv1dPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv2dPattern(), quantization_config),
-        CadenceAtenQuantizer(LayerNormPattern(), quantization_config),
-        CadenceAtenQuantizer(LinearPattern(), quantization_config),
-        CadenceAtenQuantizer(MatmulPattern(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern0(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern1(), quantization_config),
+        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8uW8u),
     ]
 
 
@@ -178,10 +192,9 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """
 
-    def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
-        if qconfig is None:
-            qconfig = _default_qconfig
-        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
         super().__init__(quantizers)
 
 
backends/cadence/hifi/operators/op_quantized_relu_out.cpp

Lines changed: 3 additions & 32 deletions
@@ -18,33 +18,6 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
-template <typename T>
-void quantized_relu_(
-    const Tensor& input,
-    const Tensor& in_zero_point,
-    const int64_t out_zero_point,
-    const Tensor& out_multiplier,
-    const Tensor& out_shift,
-    Tensor& output) {
-  T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
-  const T* __restrict__ in = input.const_data_ptr<T>();
-  T* __restrict__ out = output.mutable_data_ptr<T>();
-
-  const int32_t* __restrict__ out_multiplier_data =
-      out_multiplier.const_data_ptr<int32_t>();
-  const int32_t* __restrict__ out_shift_data =
-      out_shift.const_data_ptr<int32_t>();
-
-  // Compute the out_scale from out_multiplier and out_shift
-  const float out_scale =
-      -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);
-
-  for (size_t i = 0, e = input.numel(); i < e; ++i) {
-    float temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
-    out[i] = kernels::quantize<T>(temp, out_scale, (int32_t)out_zero_point);
-  }
-}
-
 void quantized_relu_per_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& input,
@@ -68,7 +41,7 @@ void quantized_relu_per_tensor_out(
         _out_multiplier,
         _out_shift,
         _out_zero_point,
-        _out_zero_point,
+        0,
         255,
         input.numel());
 
@@ -85,7 +58,7 @@
         _out_multiplier,
         _out_shift,
         _out_zero_point,
-        _out_zero_point,
+        -128,
         127,
         input.numel());
 
@@ -107,9 +80,7 @@
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     Tensor& output) {
-  const uint8_t* p_in = input.const_data_ptr<uint8_t>();
-  uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
-  uint8_t _in_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0];
+  int8_t _in_zero_point = in_zero_point.const_data_ptr<int8_t>()[0];
   int32_t _out_multiplier = out_multiplier.const_data_ptr<int32_t>()[0];
  int32_t _out_shift = out_shift.const_data_ptr<int32_t>()[0];
 
