Skip to content

Add pass to convert special case of mean.dim to averagepool2d #4900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions backends/arm/operators/op_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common


@register_node_visitor
Expand All @@ -30,29 +29,4 @@ def define_node(
is_quant_node: bool,
) -> None:

input_tensor = inputs[0]
dim = node.args[1]
keep_dim = node.args[2]

# mean.dim(-1, -2) is the same as avg_pool2d when just computing mean over HW dimensions.
# Since tosa doesn't have mean.dim operation, lowers it to average pooling instead.
if dim == [-1, -2]:
if keep_dim is True:
# Given the shape format of input is (N, C, H, W)
kernel_size = [input_tensor.shape[2], input_tensor.shape[3]]
stride = [1, 1]
padding = [0, 0, 0, 0]

build_avg_pool_2d_common(
node,
tosa_graph,
input_tensor,
kernel_size,
stride,
padding,
is_quant_node,
output,
)
return

raise AssertionError("unsupported")
4 changes: 4 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from executorch.backends.arm.passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.exir.backend.compile_spec_schema import CompileSpec
Expand All @@ -33,6 +36,7 @@ def transform_to_backend_pipeline(
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(ConvertSplitToSlicePass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
Expand Down
52 changes: 52 additions & 0 deletions backends/arm/passes/meandim_to_averagepool_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Dict, Tuple

import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue

Argument = Any


class ConvertMeanDimToAveragePool(ExportPass):
"""
Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation.
"""

def call_operator(
self,
op: torch.fx.node.Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op != exir_ops.edge.aten.mean.dim:
return super().call_operator(op, args, kwargs, meta)

input_value = cast(ProxyValue, args[0])
dim = cast(list, args[1])
keep_dim = cast(bool, args[2]) if len(args) > 2 else False

# averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True
# so check the dim argument for this case
if dim == [-1, -2] and keep_dim is True:
# Given the shape format of input is (N, C, H, W)
kernel_size = [
input_value.to_tensor().size()[2],
input_value.to_tensor().size()[3],
]
stride = [1, 1]
return super().call_operator(
exir_ops.edge.aten.avg_pool2d.default,
(input_value, kernel_size, stride),
{},
meta,
)
else:
return super().call_operator(op, args, kwargs, meta)
7 changes: 6 additions & 1 deletion backends/arm/test/ops/test_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def _test_meandim_tosa_u55_BI_pipeline(
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
.check_not(
[
"executorch_exir_dialects_edge__ops_aten_mean_dim",
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
]
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)
Expand Down
75 changes: 75 additions & 0 deletions backends/arm/test/passes/test_meandim_to_averagepool2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import RunPasses


class MeanDim(torch.nn.Module):
def forward(self, x):
return torch.mean(x, dim=[-1, -2], keepdim=True)

def get_inputs(self):
return (torch.rand(1, 1280, 7, 7),)


class MeanDim2(torch.nn.Module):
def forward(self, x):
return torch.mean(x, dim=1)

def get_inputs(self):
return (torch.rand(1, 1280, 7, 7),)


class TestMeandimToAveragePool2dPass(unittest.TestCase):
"""
Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d
for the special case where dim is [-1, -2] and keepdim is True.
"""

def test_tosa_BI_meandim_to_averagepool(self):
module = MeanDim()
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
(
ArmTester(
module,
example_inputs=module.get_inputs(),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize()
.export()
.to_edge()
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I see you are doing here, thanks.

.run_passes(test_pass_stage)
.check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
)

def test_tosa_BI_meandim_no_modification(self):
module = MeanDim2()
test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
(
ArmTester(
module,
example_inputs=module.get_inputs(),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize()
.export()
.to_edge()
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
.run_passes(test_pass_stage)
.check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
.check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
)
Loading