Skip to content

Commit 2d39f78

Browse files
committed
Add test for fold qdq pass annotation
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I6154e13a5a6b75549862709d632ee6dd5c8b0e7f
1 parent bcbc4c6 commit 2d39f78

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
11+
FoldAndAnnotateQParamsPass,
12+
)
13+
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
17+
from executorch.backends.xnnpack.test.tester.tester import RunPasses
18+
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
21+
22+
class SimpleQuantizeModel(torch.nn.Module):
23+
def forward(self, x):
24+
return x + x
25+
26+
def get_inputs(self):
27+
return (torch.rand(1, 1280, 7, 7),)
28+
29+
30+
class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass):
31+
def __init__(self):
32+
super(FoldAndAnnotateQParamsPassTestClass, self).__init__(
33+
[exir_ops.edge.aten.add.Tensor]
34+
)
35+
36+
37+
class TestFoldAndAnnotateQParamsPass(unittest.TestCase):
38+
"""
39+
Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into
40+
the node and stores the quantization parameters in meta.
41+
"""
42+
43+
def test_fold_qdq_pass(self):
44+
"""
45+
Check that the pass runs for add operation and that one q node and one dq node
46+
is removed from the representation.
47+
"""
48+
module = SimpleQuantizeModel()
49+
test_pass_stage = RunPasses([FoldAndAnnotateQParamsPassTestClass])
50+
(
51+
ArmTester(
52+
module,
53+
example_inputs=module.get_inputs(),
54+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
55+
)
56+
.quantize()
57+
.export()
58+
.to_edge()
59+
.check_count(
60+
{
61+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
62+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
63+
}
64+
)
65+
.run_passes(test_pass_stage)
66+
.check_count(
67+
{
68+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
69+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
70+
}
71+
)
72+
)

0 commit comments

Comments
 (0)