Skip to content

Commit 1d5bc33

Browse files
committed
Add test of meandim to averagepool pass
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I363d9df24c9c5c1c507a2a3a40358527e65d874f
1 parent d943bba commit 1d5bc33

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

backends/arm/passes/meandim_to_averagepool_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def call_operator(
3131

3232
input_value = cast(ProxyValue, args[0])
3333
dim = cast(list, args[1])
34-
keep_dim = cast(bool, args[2])
34+
keep_dim = cast(bool, args[2]) if len(args) > 2 else False
3535

3636
# averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True
3737
# so check the dim argument for this case
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
11+
ConvertMeanDimToAveragePool,
12+
)
13+
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
17+
from executorch.backends.xnnpack.test.tester.tester import RunPasses
18+
19+
20+
class MeanDim(torch.nn.Module):
21+
def forward(self, x):
22+
return torch.mean(x, dim=[-1, -2], keepdim=True)
23+
24+
def get_inputs(self):
25+
return (torch.rand(1, 1280, 7, 7),)
26+
27+
28+
class MeanDim2(torch.nn.Module):
29+
def forward(self, x):
30+
return torch.mean(x, dim=1)
31+
32+
def get_inputs(self):
33+
return (torch.rand(1, 1280, 7, 7),)
34+
35+
36+
class TestMeandimToAveragePool2dPass(unittest.TestCase):
37+
"""
38+
Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d
39+
for the special case where dim is [-1, -2] and keepdim is True.
40+
"""
41+
42+
def test_tosa_BI_meandim_to_averagepool(self):
43+
module = MeanDim()
44+
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
45+
(
46+
ArmTester(
47+
module,
48+
example_inputs=module.get_inputs(),
49+
compile_spec=common.get_tosa_compile_spec(),
50+
)
51+
.quantize()
52+
.export()
53+
.to_edge()
54+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
55+
.run_passes(test_pass_stage)
56+
.check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
57+
)
58+
59+
def test_tosa_BI_meandim_no_modification(self):
60+
module = MeanDim2()
61+
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
62+
(
63+
ArmTester(
64+
module,
65+
example_inputs=module.get_inputs(),
66+
compile_spec=common.get_tosa_compile_spec(),
67+
)
68+
.quantize()
69+
.export()
70+
.to_edge()
71+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
72+
.run_passes(test_pass_stage)
73+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
74+
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
75+
)

0 commit comments

Comments
 (0)