Skip to content

Commit 379129d

Browse files
Arm backend: Add support for BN fusing during QAT (#10967)
Makes it possible to annotate patterns with more than two operators. This allows us to annotate the patterns conv -> bn and conv -> bn -> relu, so that the batch norm (BN) can be folded away after quantization-aware training (QAT). Also adds QAT support to the Tester class. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 56018e1 commit 379129d

File tree

3 files changed

+160
-25
lines changed

3 files changed

+160
-25
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
import torch.fx
13+
import torch.nn.functional as F
1314
from executorch.backends.arm.quantizer import QuantizationConfig
1415
from executorch.backends.arm.tosa_utils import get_node_debug_info
1516
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
@@ -142,29 +143,33 @@ def _match_pattern(
142143
143144
Each 'pattern' element is composed of a list of disjunctive nodes types.
144145
"""
145-
assert len(pattern) == 2, "Only two-nodes patterns supported currently"
146-
147-
if node.target in pattern[0]:
148-
assert len(node.users) != 0
149-
parent = node
150-
child = next(iter(node.users))
151-
elif node.target in pattern[1]:
152-
assert len(node.args) != 0
153-
parent = node.args[0] # type: ignore[assignment]
154-
child = node
155-
else:
156-
return False
157-
158-
if len(parent.users) != 1:
159-
return False
160-
161-
if parent.target not in pattern[0] or child.target not in pattern[1]:
162-
return False
163-
146+
assert len(pattern) > 0, "No pattern provided"
164147
if filter_fn is not None:
165-
return filter_fn(parent) and filter_fn(child)
166-
167-
return True
148+
if not filter_fn(node):
149+
return False
150+
if len(pattern) == 1:
151+
# Base case where it has passed the filter_fn. Simply look if node.target is in pattern.
152+
return node.target in pattern[0]
153+
if node.target not in [op for sub_pattern in pattern for op in sub_pattern]:
154+
# node.target not in pattern. No need to look at the rest of the pattern.
155+
return False
156+
# Find the index of this node's target in pattern
157+
idx = [node.target in sub_pattern for sub_pattern in pattern].index(True)
158+
left_pattern = pattern[:idx]
159+
# Exclude idx as this contains node.target which we have already matched
160+
right_pattern = pattern[idx + 1 :]
161+
left_condition = True
162+
right_condition = True
163+
# Recursively look at the rest of the pattern by calling this function for
164+
# node's input and user node with updated patterns.
165+
if len(left_pattern) > 0:
166+
parent = node.all_input_nodes[0]
167+
if len(parent.users) != 1:
168+
return False
169+
left_condition = _match_pattern(parent, left_pattern, filter_fn)
170+
if len(right_pattern) > 0:
171+
right_condition = _match_pattern(list(node.users)[0], right_pattern, filter_fn)
172+
return left_condition and right_condition
168173

169174

170175
_one_to_one = [
@@ -274,6 +279,58 @@ def any_or_hardtanh_min_zero(n: Node):
274279
return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
275280

276281
if _match_pattern(
282+
node,
283+
[
284+
[
285+
torch.ops.aten.conv1d.default,
286+
torch.ops.aten.conv2d.default,
287+
torch.ops.aten.conv2d.padding,
288+
],
289+
[torch.ops.aten.batch_norm.default, F.batch_norm],
290+
[torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
291+
],
292+
filter_fn=any_or_hardtanh_min_zero,
293+
):
294+
if node.target in (
295+
torch.ops.aten.conv1d.default,
296+
torch.ops.aten.conv2d.default,
297+
torch.ops.aten.conv2d.padding,
298+
):
299+
quant_properties.quant_inputs = [
300+
_QuantProperty(0, input_act_qspec),
301+
_QuantProperty(1, weight_qspec, mark_annotated=True),
302+
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
303+
]
304+
elif node.target in (
305+
torch.ops.aten.relu.default,
306+
torch.ops.aten.hardtanh.default,
307+
):
308+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
309+
310+
elif _match_pattern(
311+
node,
312+
[
313+
[
314+
torch.ops.aten.conv1d.default,
315+
torch.ops.aten.conv2d.default,
316+
torch.ops.aten.conv2d.padding,
317+
],
318+
[torch.ops.aten.batch_norm.default, F.batch_norm],
319+
],
320+
):
321+
if node.target in (
322+
torch.ops.aten.conv1d.default,
323+
torch.ops.aten.conv2d.default,
324+
torch.ops.aten.conv2d.padding,
325+
):
326+
quant_properties.quant_inputs = [
327+
_QuantProperty(0, input_act_qspec),
328+
_QuantProperty(1, weight_qspec, mark_annotated=True),
329+
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
330+
]
331+
elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
332+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
333+
elif _match_pattern(
277334
node,
278335
[
279336
[
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
get_symmetric_quantization_config,
12+
TOSAQuantizer,
13+
)
14+
from executorch.backends.arm.test import common, conftest
15+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI
16+
17+
from executorch.backends.xnnpack.test.tester.tester import Quantize
18+
from torch import nn
19+
20+
21+
# Type of the positional inputs handed to the model under test: a 1-tuple (x,).
input_t1 = Tuple[torch.Tensor]


class ConvModule(torch.nn.Module):
    """Small conv -> (optional batch norm) -> relu network for the QAT tests.

    With ``batch_norm=False`` the normalization step is an ``nn.Identity``,
    so the forward pass reduces to conv -> relu.
    """

    # Per-sample input shape fed to the convolution, without the batch dimension.
    input_shape = (1, 28, 28)
    batch_size = 64
    # Random example input, created once when the class body is executed and
    # shared by every instance.
    test_data: input_t1 = (torch.randn(batch_size, *input_shape),)

    def __init__(self, batch_norm: bool = True) -> None:
        """Create the convolution and either a BatchNorm2d or an Identity layer."""
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
        if batch_norm:
            self.bn = nn.BatchNorm2d(num_features=16)
        else:
            self.bn = nn.Identity()

    def forward(self, x: torch.Tensor):
        """Apply the convolution, then ``self.bn``, then ReLU to ``x``."""
        return F.relu(self.bn(self.conv(x)))
40+
41+
42+
# Models under test, keyed by the name passed to the parametrized test:
# one variant with batch norm between conv and relu, one without.
models = {
    name: ConvModule(batch_norm=use_batch_norm)
    for name, use_batch_norm in (("conv_bn_relu", True), ("conv_relu", False))
}
46+
47+
48+
@common.parametrize("model", models)
def test_qat_tosa_BI(model: torch.nn.Module):
    """Run the TOSA BI pipeline on ``model`` with a QAT quantize stage.

    The default "quantize" stage arguments are replaced with a ``Quantize``
    stage built from a ``TOSAQuantizer`` and a symmetric QAT config, both
    flagged ``is_qat=True``.
    """
    pipeline = TosaPipelineBI[input_t1](model, model.test_data, [], [], qtol=1)

    # Look up the TOSA specification for the version selected via the test options.
    selected_version = conftest.get_option("tosa_version")
    specs_by_version = {
        version: common.TosaSpecification.create_from_string(spec_string)
        for version, spec_string in (
            ("0.80", "TOSA-0.80+BI"),
            ("1.0", "TOSA-1.0+INT"),
        )
    }

    qat_quantize_stage = Quantize(
        quantizer=TOSAQuantizer(specs_by_version[selected_version]),
        quantization_config=get_symmetric_quantization_config(is_qat=True),
        is_qat=True,
    )
    pipeline.change_args("quantize", qat_quantize_stage)
    pipeline.run()

backends/xnnpack/test/tester/tester.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,11 @@
5555
)
5656
from executorch.exir.program._program import _transform
5757
from torch._export.pass_base import PassType
58-
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
58+
from torch.ao.quantization.quantize_pt2e import (
59+
convert_pt2e,
60+
prepare_pt2e,
61+
prepare_qat_pt2e,
62+
)
5963
from torch.ao.quantization.quantizer.quantizer import Quantizer
6064
from torch.export import export, ExportedProgram
6165
from torch.testing import FileCheck
@@ -150,26 +154,34 @@ def __init__(
150154
quantization_config: Optional[QuantizationConfig] = None,
151155
calibrate: bool = True,
152156
calibration_samples: Optional[Sequence[Any]] = None,
157+
is_qat: Optional[bool] = False,
153158
):
154159
self.quantizer = quantizer or XNNPACKQuantizer()
155160
self.quantization_config = (
156-
quantization_config or get_symmetric_quantization_config()
161+
quantization_config or get_symmetric_quantization_config(is_qat=is_qat)
157162
)
158163
self.calibrate = calibrate
159164
self.calibration_samples = calibration_samples
160165

161166
self.quantizer.set_global(self.quantization_config)
162167

163168
self.converted_graph = None
169+
self.is_qat = is_qat
164170

165171
def run(
166172
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
167173
) -> None:
168174
assert inputs is not None
175+
if self.is_qat:
176+
artifact.train()
169177
captured_graph = export_for_training(artifact, inputs, strict=True).module()
170178

171179
assert isinstance(captured_graph, torch.fx.GraphModule)
172-
prepared = prepare_pt2e(captured_graph, self.quantizer)
180+
181+
if self.is_qat:
182+
prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
183+
else:
184+
prepared = prepare_pt2e(captured_graph, self.quantizer)
173185

174186
if self.calibrate:
175187
# Calibrate prepared model to provide data to quantization observers.

0 commit comments

Comments
 (0)