Commit 3507412

yifan_shen3 authored and facebook-github-bot committed
Add CoreMLQuantizer (#2338)
Summary: Add CoreMLQuantizer, and a test to demo it:

```
# Given a PyTorch model and a tuple of example inputs
pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)

quantization_config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": QuantizationScheme.symmetric,
            "milestones": [0, 0, 10, 10],
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(quantization_config)

# Use `prepare_pt2e` for post-training quantization,
# `prepare_qat_pt2e` for quantization-aware training
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
prepared_graph(*example_inputs)
converted_graph = convert_pt2e(prepared_graph)
```

Pull Request resolved: #2338

Reviewed By: kirklandsign

Differential Revision: D54781807

Pulled By: shoumikhin

fbshipit-source-id: b33dd6552f01544b2dd0fdf00f2338e78499bfa5
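The summary demonstrates only the PTQ path; the QAT path differs in using `prepare_qat_pt2e` and in running training steps before conversion. A minimal sketch of that variant, continuing the snippet above (the short fine-tuning loop, optimizer, and placeholder loss are illustrative assumptions, not part of the commit):

```
# QAT variant: prepare with prepare_qat_pt2e, train briefly so the
# fake-quantize observers see data, then convert as before.
prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer)

optimizer = torch.optim.SGD(prepared_graph.parameters(), lr=1e-3)  # hypothetical optimizer
for _ in range(10):
    optimizer.zero_grad()
    out = prepared_graph(*example_inputs)
    loss = out.abs().mean()  # placeholder loss, stands in for a real objective
    loss.backward()
    optimizer.step()

converted_graph = convert_pt2e(prepared_graph)
```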
1 parent 624ce59 commit 3507412

File tree

2 files changed: +112 -0 lines

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Copyright © 2024 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
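This first file is a one-line public re-export: downstream code imports `CoreMLQuantizer` from the executorch namespace instead of reaching into coremltools' private `_coreml_quantizer` module. A minimal usage sketch (the import path is taken from the test file below):

```
# Import via the public executorch path rather than the private coremltools module
from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
```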
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+# Copyright © 2024 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+import numpy as np
+import pytest
+from typing import Tuple
+
+import torch
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
+
+from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer
+
+from coremltools.optimize.torch.quantization.quantization_config import (
+    LinearQuantizerConfig,
+    QuantizationScheme,
+)
+
+
+class TestCoreMLQuantizer:
+    @staticmethod
+    def quantize_and_compare(
+        model,
+        example_inputs: Tuple[torch.Tensor],
+        quantization_type: str,
+    ) -> None:
+        assert quantization_type in {"PTQ", "QAT"}
+
+        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
+
+        quantization_config = LinearQuantizerConfig.from_dict(
+            {
+                "global_config": {
+                    "quantization_scheme": QuantizationScheme.symmetric,
+                    "milestones": [0, 0, 10, 10],
+                    "activation_dtype": torch.quint8,
+                    "weight_dtype": torch.qint8,
+                    "weight_per_channel": True,
+                }
+            }
+        )
+        quantizer = CoreMLQuantizer(quantization_config)
+
+        if quantization_type == "PTQ":
+            prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
+        elif quantization_type == "QAT":
+            prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer)
+
+        prepared_graph(*example_inputs)
+        converted_graph = convert_pt2e(prepared_graph)
+
+        model_output = model(*example_inputs).detach().numpy()
+        quantized_output = converted_graph(*example_inputs).detach().numpy()
+        np.testing.assert_allclose(quantized_output, model_output, rtol=5e-2, atol=5e-2)
+
+    @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT"))
+    def test_conv_relu(self, quantization_type):
+        SHAPE = (1, 3, 256, 256)
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3, out_channels=16, kernel_size=3, padding=1
+                )
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                a = self.conv(x)
+                return self.relu(a)
+
+        model = Model()
+
+        example_inputs = (torch.randn(SHAPE),)
+        self.quantize_and_compare(
+            model,
+            example_inputs,
+            quantization_type,
+        )
+
+    @pytest.mark.parametrize("quantization_type", ("PTQ", "QAT"))
+    def test_linear(self, quantization_type):
+        SHAPE = (1, 5)
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.linear(x)
+
+        model = Model()
+
+        example_inputs = (torch.randn(SHAPE),)
+        self.quantize_and_compare(
+            model,
+            example_inputs,
+            quantization_type,
+        )
+
+
+if __name__ == "__main__":
+    test_runner = TestCoreMLQuantizer()
+    test_runner.test_conv_relu("PTQ")
+    test_runner.test_linear("QAT")
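For context, the `converted_graph` returned by `convert_pt2e` is what would feed into the ExecuTorch export pipeline. A minimal sketch of that next step (not part of this commit; it assumes the standard `torch.export` / `executorch.exir` flow and omits delegation to the CoreML backend):

```
import torch
from executorch.exir import to_edge

# Export the quantized graph, lower it to the edge dialect, and serialize
# an ExecuTorch program; "model.pte" is an arbitrary output name.
exported_program = torch.export.export(converted_graph, example_inputs)
executorch_program = to_edge(exported_program).to_executorch()

with open("model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```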
