Skip to content

Commit 63e794a

Browse files
authored
Add pass to convert special case of mean.dim to averagepool2d
Differential Revision: D62034655 Pull Request resolved: #4900
1 parent 67ae762 commit 63e794a

File tree

5 files changed

+137
-27
lines changed

5 files changed

+137
-27
lines changed

backends/arm/operators/op_mean_dim.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
1514

1615

1716
@register_node_visitor
@@ -30,29 +29,4 @@ def define_node(
3029
is_quant_node: bool,
3130
) -> None:
3231

33-
input_tensor = inputs[0]
34-
dim = node.args[1]
35-
keep_dim = node.args[2]
36-
37-
# mean.dim(-1, -2) is the same as avg_pool2d when just computing mean over HW dimensions.
38-
# Since tosa doesn't have mean.dim operation, lowers it to average pooling instead.
39-
if dim == [-1, -2]:
40-
if keep_dim is True:
41-
# Given the shape format of input is (N, C, H, W)
42-
kernel_size = [input_tensor.shape[2], input_tensor.shape[3]]
43-
stride = [1, 1]
44-
padding = [0, 0, 0, 0]
45-
46-
build_avg_pool_2d_common(
47-
node,
48-
tosa_graph,
49-
input_tensor,
50-
kernel_size,
51-
stride,
52-
padding,
53-
is_quant_node,
54-
output,
55-
)
56-
return
57-
5832
raise AssertionError("unsupported")

backends/arm/passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from executorch.backends.arm.passes.convert_split_to_slice import (
1616
ConvertSplitToSlicePass,
1717
)
18+
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
19+
ConvertMeanDimToAveragePool,
20+
)
1821
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
1922
from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
2023
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -33,6 +36,7 @@ def transform_to_backend_pipeline(
3336
self.add_pass(SizeAdjustConv2DPass())
3437
self.add_pass(RemoveClonePass())
3538
self.add_pass(ConvertExpandCopyToRepeatPass())
39+
self.add_pass(ConvertMeanDimToAveragePool())
3640
self.add_pass(ConvertSplitToSlicePass())
3741
for spec in compile_spec:
3842
if spec.key == "permute_memory_format":
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, cast, Dict, Tuple
8+
9+
import torch.fx
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
13+
14+
Argument = Any
15+
16+
17+
class ConvertMeanDimToAveragePool(ExportPass):
18+
"""
19+
Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation.
20+
"""
21+
22+
def call_operator(
23+
self,
24+
op: torch.fx.node.Target,
25+
args: Tuple[Argument, ...],
26+
kwargs: Dict[str, Argument],
27+
meta: NodeMetadata,
28+
) -> ProxyValue:
29+
if op != exir_ops.edge.aten.mean.dim:
30+
return super().call_operator(op, args, kwargs, meta)
31+
32+
input_value = cast(ProxyValue, args[0])
33+
dim = cast(list, args[1])
34+
keep_dim = cast(bool, args[2]) if len(args) > 2 else False
35+
36+
# averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True
37+
# so check the dim argument for this case
38+
if dim == [-1, -2] and keep_dim is True:
39+
# Given the shape format of input is (N, C, H, W)
40+
kernel_size = [
41+
input_value.to_tensor().size()[2],
42+
input_value.to_tensor().size()[3],
43+
]
44+
stride = [1, 1]
45+
return super().call_operator(
46+
exir_ops.edge.aten.avg_pool2d.default,
47+
(input_value, kernel_size, stride),
48+
{},
49+
meta,
50+
)
51+
else:
52+
return super().call_operator(op, args, kwargs, meta)

backends/arm/test/ops/test_mean_dim.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ def _test_meandim_tosa_u55_BI_pipeline(
106106
.check(["torch.ops.quantized_decomposed"])
107107
.to_edge()
108108
.partition()
109-
.check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
109+
.check_not(
110+
[
111+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
112+
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
113+
]
114+
)
110115
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
111116
.to_executorch()
112117
)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
11+
ConvertMeanDimToAveragePool,
12+
)
13+
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
17+
from executorch.backends.xnnpack.test.tester.tester import RunPasses
18+
19+
20+
class MeanDim(torch.nn.Module):
21+
def forward(self, x):
22+
return torch.mean(x, dim=[-1, -2], keepdim=True)
23+
24+
def get_inputs(self):
25+
return (torch.rand(1, 1280, 7, 7),)
26+
27+
28+
class MeanDim2(torch.nn.Module):
29+
def forward(self, x):
30+
return torch.mean(x, dim=1)
31+
32+
def get_inputs(self):
33+
return (torch.rand(1, 1280, 7, 7),)
34+
35+
36+
class TestMeandimToAveragePool2dPass(unittest.TestCase):
37+
"""
38+
Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d
39+
for the special case where dim is [-1, -2] and keepdim is True.
40+
"""
41+
42+
def test_tosa_BI_meandim_to_averagepool(self):
43+
module = MeanDim()
44+
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
45+
(
46+
ArmTester(
47+
module,
48+
example_inputs=module.get_inputs(),
49+
compile_spec=common.get_tosa_compile_spec(),
50+
)
51+
.quantize()
52+
.export()
53+
.to_edge()
54+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
55+
.run_passes(test_pass_stage)
56+
.check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
57+
)
58+
59+
def test_tosa_BI_meandim_no_modification(self):
60+
module = MeanDim2()
61+
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
62+
(
63+
ArmTester(
64+
module,
65+
example_inputs=module.get_inputs(),
66+
compile_spec=common.get_tosa_compile_spec(),
67+
)
68+
.quantize()
69+
.export()
70+
.to_edge()
71+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
72+
.run_passes(test_pass_stage)
73+
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
74+
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
75+
)

0 commit comments

Comments
 (0)