Skip to content

Commit 2dbb0c5

Browse files
committed
Increase calibration samples and tolerance for flaky quantized op tests
1 parent 06ef713 commit 2dbb0c5

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

backends/xnnpack/test/ops/test_add.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unittest
88

99
import torch
10-
from executorch.backends.xnnpack.test.tester import Tester
10+
from executorch.backends.xnnpack.test.tester import Quantize, Tester
1111

1212

1313
class TestAdd(unittest.TestCase):
@@ -136,9 +136,12 @@ def test_qs8_add2(self):
136136

137137
def test_qs8_add3(self):
138138
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
139+
calibration_samples = [
140+
(torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)) for _ in range(100)
141+
]
139142
(
140143
Tester(self.Add(), inputs)
141-
.quantize()
144+
.quantize(Quantize(calibration_samples=calibration_samples))
142145
.export()
143146
.check_count({"torch.ops.aten.add.Tensor": 4})
144147
.check(["torch.ops.quantized_decomposed"])
@@ -152,7 +155,7 @@ def test_qs8_add3(self):
152155
)
153156
.to_executorch()
154157
.serialize()
155-
.run_method_and_compare_outputs()
158+
.run_method_and_compare_outputs(num_runs=10, atol=0.02, rtol=0.02)
156159
)
157160

158161
class AddRelu(torch.nn.Module):

backends/xnnpack/test/ops/test_conv1d.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
1414
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
1515

16-
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
16+
from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
1717
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
1818
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
1919

@@ -98,9 +98,17 @@ def _test_conv1d(
9898
stage=None,
9999
skip_to_executorch=False,
100100
):
101+
calibration_samples = (
102+
[tuple(torch.randn_like(inputs[i]) for i in range(len(inputs)))]
103+
if quantized
104+
else None
105+
)
106+
101107
tester = (
102108
(
103-
Tester(module, inputs, dynamic_shape).quantize()
109+
Tester(module, inputs, dynamic_shape).quantize(
110+
Quantize(calibration_samples=calibration_samples)
111+
)
104112
if quantized
105113
else Tester(module, inputs)
106114
)
@@ -114,7 +122,9 @@ def _test_conv1d(
114122
# For some tests we want to skip to_executorch because otherwise it will require the
115123
# quantized operators to be loaded and we don't want to do that in the test.
116124
if not skip_to_executorch:
117-
tester.to_executorch().serialize().run_method_and_compare_outputs()
125+
tester.to_executorch().serialize().run_method_and_compare_outputs(
126+
num_runs=10, atol=0.01, rtol=0.01
127+
)
118128

119129
def test_fp16_conv1d(self):
120130
inputs = (torch.randn(2, 2, 4).to(torch.float16),)

backends/xnnpack/test/tester/tester.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sys
1313
from abc import ABC, abstractmethod
1414
from collections import Counter, OrderedDict
15-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
15+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
1616

1717
import torch
1818
from executorch.backends.xnnpack._passes import XNNPACKPassManager
@@ -146,12 +146,14 @@ def __init__(
146146
quantizer: Optional[Quantizer] = None,
147147
quantization_config: Optional[QuantizationConfig] = None,
148148
calibrate: bool = True,
149+
calibration_samples: Optional[Sequence[Any]] = None,
149150
):
150151
self.quantizer = quantizer or XNNPACKQuantizer()
151152
self.quantization_config = (
152153
quantization_config or get_symmetric_quantization_config()
153154
)
154155
self.calibrate = calibrate
156+
self.calibration_samples = calibration_samples
155157

156158
self.quantizer.set_global(self.quantization_config)
157159

@@ -168,7 +170,11 @@ def run(
168170

169171
if self.calibrate:
170172
# Calibrate prepared model to provide data to quantization observers.
171-
prepared(*inputs)
173+
if self.calibration_samples is not None:
174+
for inp in self.calibration_samples:
175+
prepared(*inp)
176+
else:
177+
prepared(*inputs)
172178

173179
converted = convert_pt2e(prepared)
174180
self.converted_graph = converted

0 commit comments

Comments
 (0)