Skip to content

Commit 9225998

Browse files
committed
Add pass to convert special case of mean.dim to averagepool
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I5f2e26ee674cee9df5ffec3d4923466dea4ed463
1 parent ce4917c commit 9225998

File tree

5 files changed

+65
-28
lines changed

5 files changed

+65
-28
lines changed

backends/arm/operators/op_mean_dim.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
register_node_visitor,
1212
)
1313
from executorch.backends.arm.tosa_mapping import TosaArg
14-
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
1514

1615

1716
@register_node_visitor
@@ -30,29 +29,4 @@ def define_node(
3029
is_quant_node: bool,
3130
) -> None:
3231

33-
input_tensor = inputs[0]
34-
dim = node.args[1]
35-
keep_dim = node.args[2]
36-
37-
# mean.dim(-1, -2) is the same as avg_pool2d when just computing mean over HW dimensions.
38-
# Since tosa doesn't have mean.dim operation, lowers it to average pooling instead.
39-
if dim == [-1, -2]:
40-
if keep_dim is True:
41-
# Given the shape format of input is (N, C, H, W)
42-
kernel_size = [input_tensor.shape[2], input_tensor.shape[3]]
43-
stride = [1, 1]
44-
padding = [0, 0, 0, 0]
45-
46-
build_avg_pool_2d_common(
47-
node,
48-
tosa_graph,
49-
input_tensor,
50-
kernel_size,
51-
stride,
52-
padding,
53-
is_quant_node,
54-
output,
55-
)
56-
return
57-
5832
raise AssertionError("unsupported")

backends/arm/passes/annotate_channels_last_dim_order_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import collections.abc
8+
79
import torch
810
from executorch.backends.arm.tosa_quant_utils import dq_op
911
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
@@ -46,7 +48,7 @@ def call(self, graph_module: torch.fx.GraphModule):
4648
NHWC_Order = (0, 2, 3, 1)
4749
HWCM_Order = (2, 3, 0, 1)
4850
for node in graph_module.graph.nodes:
49-
if isinstance(node.meta["val"], tuple):
51+
if isinstance(node.meta["val"], collections.abc.Sequence):
5052
node_data = node.meta["val"][0].data
5153
else:
5254
node_data = node.meta["val"].data

backends/arm/passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from executorch.backends.arm.passes.convert_split_to_slice import (
1616
ConvertSplitToSlicePass,
1717
)
18+
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
19+
ConvertMeanDimToAveragePool,
20+
)
1821
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
1922
from executorch.exir.backend.compile_spec_schema import CompileSpec
2023
from executorch.exir.pass_manager import PassManager
@@ -31,6 +34,7 @@ def transform_to_backend_pipeline(
3134
"""Apply passes before transforming program to backend"""
3235
self.add_pass(RemoveClonePass())
3336
self.add_pass(ConvertExpandCopyToRepeatPass())
37+
self.add_pass(ConvertMeanDimToAveragePool())
3438
self.add_pass(ConvertSplitToSlicePass())
3539
for spec in compile_spec:
3640
if spec.key == "permute_memory_format":
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, cast, Dict, Tuple
8+
9+
import torch.fx
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
13+
14+
Argument = Any
15+
16+
17+
class ConvertMeanDimToAveragePool(ExportPass):
18+
"""
19+
Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation.
20+
"""
21+
22+
def call_operator(
23+
self,
24+
op: torch.fx.node.Target,
25+
args: Tuple[Argument, ...],
26+
kwargs: Dict[str, Argument],
27+
meta: NodeMetadata,
28+
) -> ProxyValue:
29+
if op != exir_ops.edge.aten.mean.dim:
30+
return super().call_operator(op, args, kwargs, meta)
31+
32+
input_value = cast(ProxyValue, args[0])
33+
dim = cast(list, args[1])
34+
keep_dim = cast(bool, args[2])
35+
36+
# averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True
37+
# so check the dim argument for this case
38+
if dim == [-1, -2] and keep_dim is True:
39+
# Given the shape format of input is (N, C, H, W)
40+
kernel_size = [
41+
input_value.to_tensor().size()[2],
42+
input_value.to_tensor().size()[3],
43+
]
44+
stride = [1, 1]
45+
return super().call_operator(
46+
exir_ops.edge.aten.avg_pool2d.default,
47+
(input_value, kernel_size, stride),
48+
{},
49+
meta,
50+
)
51+
else:
52+
return super().call_operator(op, args, kwargs, meta)

backends/arm/test/ops/test_mean_dim.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,12 @@ def _test_meandim_tosa_u55_BI_pipeline(
106106
.check(["torch.ops.quantized_decomposed"])
107107
.to_edge()
108108
.partition()
109-
.check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
109+
.check_not(
110+
[
111+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
112+
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
113+
]
114+
)
110115
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
111116
.to_executorch()
112117
)

0 commit comments

Comments
 (0)