
Commit 4b02da3

mcremon-meta authored and facebook-github-bot committed
Use symmetric weights for convs and int8 in the default quantizer (#8344)
Summary: As titled. int8 should give better performance with Cadence kernels, since the uint8 kernels are no longer being improved. The upcoming (quantized) convolution kernel needs symmetric weights, so we make that change as well.

Reviewed By: zonglinpeng

Differential Revision: D69405797
1 parent ee7d388 commit 4b02da3

File tree: 4 files changed, +48 -94 lines changed

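For context on the summary above, a minimal sketch (not part of this commit) of the difference between asymmetric uint8 and symmetric int8 weight quantization parameters. The helper names are illustrative only, and the math is the standard min/max formulation, simplified (no zero-point clamping or degenerate-range handling):

import torch

def asym_uint8_params(w: torch.Tensor):
    # Asymmetric: scale and zero_point map [min, max] onto [0, 255].
    w_min, w_max = w.min().item(), w.max().item()
    scale = (w_max - w_min) / 255.0
    zero_point = int(round(-w_min / scale))
    return scale, zero_point

def sym_int8_params(w: torch.Tensor):
    # Symmetric: zero_point is fixed at 0 and the range is centered on it,
    # which is what a kernel expecting symmetric weights relies on.
    scale = w.abs().max().item() / 127.0
    return scale, 0

w = torch.randn(16, 8)
print(asym_uint8_params(w))
print(sym_int8_params(w))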

backends/cadence/aot/export_example.py
Lines changed: 3 additions & 33 deletions

@@ -6,11 +6,11 @@
 
 # Example script for exporting simple models to flatbuffer
 
+#pyre-unsafe
+
 import logging
 import tempfile
 
-import torch
-
 from executorch.backends.cadence.aot.ops_registrations import *  # noqa
 from typing import Any, Tuple
 
@@ -23,38 +23,15 @@
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
-from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
-    QuantizationConfig,
-    QuantizationSpec,
-)
 from executorch.exir import ExecutorchProgramManager
 from torch import nn
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
 
 from .utils import save_bpte_program, save_pte_program
 
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
-
 
 def export_model(
     model: nn.Module,
@@ -66,15 +43,8 @@ def export_model(
     working_dir = tempfile.mkdtemp(dir="/tmp")
     logging.debug(f"Created work directory {working_dir}")
 
-    qconfig = QuantizationConfig(
-        act_qspec,
-        act_qspec,
-        wgt_qspec,
-        None,
-    )
-
     # Instantiate the quantizer
-    quantizer = CadenceDefaultQuantizer(qconfig)
+    quantizer = CadenceDefaultQuantizer()
 
     # Convert the model
     converted_model = convert_pt2(model, example_inputs, quantizer)
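
With this change the example no longer builds a QuantizationConfig by hand. A minimal usage sketch based on the updated example; the model, the inputs, and the convert_pt2 import path are assumptions for illustration:

import torch
from torch import nn

from executorch.backends.cadence.aot.compiler import convert_pt2  # assumed import path
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = nn.Linear(16, 8)                # placeholder model
example_inputs = (torch.randn(1, 16),)  # placeholder inputs

# The default quantizer now carries the int8 / symmetric-weight configs itself.
quantizer = CadenceDefaultQuantizer()
converted_model = convert_pt2(model, example_inputs, quantizer)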

backends/cadence/aot/ops_registrations.py
Lines changed: 1 addition & 1 deletion

@@ -576,7 +576,7 @@ def quantized_relu_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
-    return input.new_empty(input.size(), dtype=torch.uint8)
+    return input.new_empty(input.size(), dtype=input.dtype)
 
 
 @register_fake("cadence::fully_connected")
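
The fake (meta) kernel used to pin the output to uint8 and now follows the input dtype. A self-contained sketch of that behavioral difference, outside the Cadence op registration:

import torch

def relu_meta_old(input: torch.Tensor) -> torch.Tensor:
    # Old behavior: output dtype hardcoded to uint8.
    return input.new_empty(input.size(), dtype=torch.uint8)

def relu_meta_new(input: torch.Tensor) -> torch.Tensor:
    # New behavior: output dtype matches the (now int8) input.
    return input.new_empty(input.size(), dtype=input.dtype)

x = torch.zeros(4, dtype=torch.int8)
print(relu_meta_old(x).dtype)  # torch.uint8
print(relu_meta_new(x).dtype)  # torch.int8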

backends/cadence/aot/quantizer/quantizer.py
Lines changed: 41 additions & 28 deletions

@@ -40,30 +40,46 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 
 
-act_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+act_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
 )
 
-wgt_qspec = QuantizationSpec(
-    dtype=torch.uint8,
-    quant_min=0,
-    quant_max=255,
+wgt_qspec_asym8u = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
     qscheme=torch.per_tensor_affine,
     is_dynamic=False,
     observer_or_fake_quant_ctr=MinMaxObserver,
 )
 
+wgt_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
 bias_qspec: Optional[QuantizationSpec] = None
 
-_default_qconfig = QuantizationConfig(
-    act_qspec,
-    act_qspec,
-    wgt_qspec,
+qconfig_A8uW8u = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8u,
+    None,
+)
+
+qconfig_A8uW8s = QuantizationConfig(
+    act_qspec_asym8u,
+    act_qspec_asym8u,
+    wgt_qspec_asym8s,
     None,
 )
 
@@ -147,19 +163,17 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
         return []
 
 
-def get_cadence_default_quantizer_list_with_config(
-    quantization_config: QuantizationConfig,
-) -> List[Quantizer]:
+def get_cadence_default_quantizers() -> List[Quantizer]:
     return [
-        CadenceAtenQuantizer(AddmmPattern(), quantization_config),
-        CadenceAtenQuantizer(BmmPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv1dPattern(), quantization_config),
-        CadenceAtenQuantizer(Conv2dPattern(), quantization_config),
-        CadenceAtenQuantizer(LayerNormPattern(), quantization_config),
-        CadenceAtenQuantizer(LinearPattern(), quantization_config),
-        CadenceAtenQuantizer(MatmulPattern(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern0(), quantization_config),
-        CadenceAtenQuantizer(ReluPattern1(), quantization_config),
+        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(BmmPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s),
+        CadenceAtenQuantizer(LayerNormPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8uW8u),
+        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8uW8u),
     ]
 
 
@@ -178,10 +192,9 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
     Default quantizer for Cadence backend.
     """
 
-    def __init__(self, qconfig: Optional[QuantizationConfig] = None) -> None:
-        if qconfig is None:
-            qconfig = _default_qconfig
-        quantizers = get_cadence_default_quantizer_list_with_config(qconfig)
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = get_cadence_default_quantizers()
         super().__init__(quantizers)
 
 
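A hedged sketch of how the new constructor could be used to supply a custom quantizer list instead of a single QuantizationConfig, reusing names from this diff. The patterns module path and whether this exact combination is supported are assumptions:

from executorch.backends.cadence.aot.quantizer.patterns import (  # assumed module path
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
)
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
    qconfig_A8uW8s,
)

# Apply the symmetric-weight config to convs and linears only; patterns left
# out of this list should simply not be annotated for quantization.
custom_quantizers = [
    CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8uW8s),
    CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8uW8s),
    CadenceAtenQuantizer(LinearPattern(), qconfig_A8uW8s),
]

quantizer = CadenceDefaultQuantizer(quantizers=custom_quantizers)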

backends/cadence/hifi/operators/op_quantized_relu_out.cpp
Lines changed: 3 additions & 32 deletions

@@ -18,33 +18,6 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
-template <typename T>
-void quantized_relu_(
-    const Tensor& input,
-    const Tensor& in_zero_point,
-    const int64_t out_zero_point,
-    const Tensor& out_multiplier,
-    const Tensor& out_shift,
-    Tensor& output) {
-  T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
-  const T* __restrict__ in = input.const_data_ptr<T>();
-  T* __restrict__ out = output.mutable_data_ptr<T>();
-
-  const int32_t* __restrict__ out_multiplier_data =
-      out_multiplier.const_data_ptr<int32_t>();
-  const int32_t* __restrict__ out_shift_data =
-      out_shift.const_data_ptr<int32_t>();
-
-  // Compute the out_scale from out_multiplier and out_shift
-  const float out_scale =
-      -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);
-
-  for (size_t i = 0, e = input.numel(); i < e; ++i) {
-    float temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
-    out[i] = kernels::quantize<T>(temp, out_scale, (int32_t)out_zero_point);
-  }
-}
-
 void quantized_relu_per_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& input,
@@ -68,7 +41,7 @@ void quantized_relu_per_tensor_out(
         _out_multiplier,
         _out_shift,
         _out_zero_point,
-        _out_zero_point,
+        0,
         255,
         input.numel());
 
@@ -85,7 +58,7 @@ void quantized_relu_per_tensor_out(
         _out_multiplier,
         _out_shift,
         _out_zero_point,
-        _out_zero_point,
+        -128,
         127,
         input.numel());
 
@@ -107,9 +80,7 @@ void quantized_relu_per_tensor_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     Tensor& output) {
-  const uint8_t* p_in = input.const_data_ptr<uint8_t>();
-  uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
-  uint8_t _in_zero_point = in_zero_point.const_data_ptr<uint8_t>()[0];
+  int8_t _in_zero_point = in_zero_point.const_data_ptr<int8_t>()[0];
   int32_t _out_multiplier = out_multiplier.const_data_ptr<int32_t>()[0];
   int32_t _out_shift = out_shift.const_data_ptr<int32_t>()[0];
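
For reference, a plain-Python sketch of what the per-tensor quantized ReLU computes: shift by the input zero point, clamp negatives to zero, rescale, add the output zero point, and clamp to the dtype range (0..255 for uint8, now -128..127 for int8, matching the bounds passed above). The rescale factor stands in for the kernel's fixed-point out_multiplier/out_shift pair, and the exact rounding of the HiFi implementation may differ:

import torch

def quantized_relu_ref(x, in_zero_point, out_zero_point, rescale, qmin, qmax):
    # Inputs at or below the input zero point represent real values <= 0.0,
    # so they collapse to the output zero point after requantization.
    relu = torch.clamp(x.to(torch.int32) - in_zero_point, min=0)
    out = torch.round(relu * rescale) + out_zero_point
    return torch.clamp(out, qmin, qmax).to(torch.int32)

x = torch.tensor([-128, -5, 0, 7, 127], dtype=torch.int8)
# int8 path (new bounds): clamp to [-128, 127]; the uint8 path clamps to [0, 255].
print(quantized_relu_ref(x, in_zero_point=-5, out_zero_point=-5, rescale=1.0,
                         qmin=-128, qmax=127))
# tensor([ -5,  -5,   0,   7, 127], dtype=torch.int32)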
