Commit 8101bf1

mcremon-meta authored and facebook-github-bot committed
Use int8 quantizer in the OSS flow (#6166)
Summary: Pull Request resolved: #6166

As titled. This change adds the ability to supply a qconfig to the `CadenceQuantizer`, and uses `int8` instead of `uint8` in `export_model`, as requested by Cadence. Going forward, `int8` should be the primary 8-bit type.

Reviewed By: dulinriley

Differential Revision: D64209639

fbshipit-source-id: c8bb385aa75cdeb0cfb92217f4cdc5335a10a3b9
1 parent ce67b54 commit 8101bf1

File tree

2 files changed: +47 -7 lines


backends/cadence/aot/export_example.py

Lines changed: 34 additions & 1 deletion
@@ -9,6 +9,8 @@
 import logging
 import tempfile

+import torch
+
 from executorch.backends.cadence.aot.ops_registrations import * # noqa
 from typing import Any, Tuple

@@ -17,18 +19,42 @@
     export_to_cadence_edge_executorch,
     fuse_pt2,
 )
+
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
 from executorch.backends.cadence.runtime import runtime
 from executorch.backends.cadence.runtime.executor import BundledProgramManager
 from executorch.exir import ExecutorchProgramManager
 from torch import nn
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
+from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
+    QuantizationConfig,
+    QuantizationSpec,
+)

 from .utils import save_bpte_program, save_pte_program


 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

+act_qspec = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+)
+
+wgt_qspec = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+

 def export_model(
     model: nn.Module,
@@ -39,8 +65,15 @@ def export_model(
     working_dir = tempfile.mkdtemp(dir="/tmp")
     logging.debug(f"Created work directory {working_dir}")

+    qconfig = QuantizationConfig(
+        act_qspec,
+        act_qspec,
+        wgt_qspec,
+        None,
+    )
+
     # Instantiate the quantizer
-    quantizer = CadenceQuantizer()
+    quantizer = CadenceQuantizer(qconfig)

     # Convert the model
     converted_model = convert_pt2(model, example_inputs, quantizer)
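To see the new int8 default in action end to end, here is a minimal sketch of driving this file's `export_model` with a toy model. It assumes `export_model` can be called as `export_model(model, example_inputs)` (the full signature is not shown in this diff), and `SmallLinear` is an illustrative module, not part of the change.

import torch
from torch import nn

from executorch.backends.cadence.aot.export_example import export_model


class SmallLinear(nn.Module):
    """Tiny model used only to exercise the int8 quantization flow."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


if __name__ == "__main__":
    model = SmallLinear().eval()
    example_inputs = (torch.randn(1, 16),)
    # With this commit, export_model builds the int8 QuantizationConfig shown
    # above (HistogramObserver for activations, MinMaxObserver for weights)
    # and passes it to CadenceQuantizer.
    export_model(model, example_inputs)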

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 13 additions & 6 deletions
@@ -141,13 +141,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:


 class CadenceQuantizer(ComposableQuantizer):
-    def __init__(self) -> None:
-        static_qconfig = QuantizationConfig(
-            act_qspec,
-            act_qspec,
-            wgt_qspec,
-            None,
+    def __init__(
+        self, quantization_config: Optional[QuantizationConfig] = None
+    ) -> None:
+        static_qconfig = (
+            QuantizationConfig(
+                act_qspec,
+                act_qspec,
+                wgt_qspec,
+                None,
+            )
+            if not quantization_config
+            else quantization_config
         )
+
         super().__init__(
             [
                 CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
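A quick note on the new constructor: `quantization_config` is optional, so existing `CadenceQuantizer()` call sites keep working while callers can now override the config. The sketch below shows both paths; the uint8 specs are illustrative (for a caller that still wants the previous 8-bit type) and are not taken from this diff.

import torch
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    QuantizationConfig,
    QuantizationSpec,
)

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

# Default path: falls back to the built-in int8 static_qconfig.
int8_quantizer = CadenceQuantizer()

# Override path: supply a custom config, e.g. uint8 with the full [0, 255]
# range. These specs are illustrative and may differ from the old defaults.
uint8_act = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
uint8_wgt = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
uint8_quantizer = CadenceQuantizer(
    QuantizationConfig(uint8_act, uint8_act, uint8_wgt, None)
)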
