Port Arm backend op unittests: conv op + other ops #2647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/arm/test/arm_tosa_reference.py
@@ -40,7 +40,6 @@
 
 SUPPORTED_BI_TEST_LIST = [
     "simple_add_broadcast",
-    "block_bottleneck_residual",
 ]
 
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_add.py
@@ -19,8 +19,6 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-torch.manual_seed(42)
-
 
 class TestSimpleAdd(unittest.TestCase):
     class Add(torch.nn.Module):
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_conv.py
@@ -18,8 +18,6 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-torch.manual_seed(42)
-
 
 class Conv2d(torch.nn.Module):
     """
272 changes: 272 additions & 0 deletions backends/arm/test/ops/test_conv_combos.py
@@ -0,0 +1,272 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.test_models import TosaProfile
from executorch.backends.arm.test.tester.arm_tester import ArmBackendSelector, ArmTester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

"""
This file contains unit tests where convolutions are combined with other ops.
"""


class ComboBlockBottleneckResidual(torch.nn.Module):
# This is the essence of MobileNetV2. Ref: https://arxiv.org/abs/1801.04381
edge_op_list = [
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
]

def __init__(self):
super().__init__()
# (t, c, n, s) = (6, 96, 1, 1)
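        # (MobileNetV2 Table 2 notation: t = expansion factor, c = output
        # channels, n = number of repeats, s = stride.)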
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
self.pointwise_conv2d = torch.nn.Conv2d(
in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1
) ## (1, 384, 81, 81)
self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False)
self.relu6 = torch.nn.ReLU6()
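        # Note: ReLU6 is lowered to aten.hardtanh(min=0, max=6), which is why
        # edge_op_list checks for the hardtanh edge op rather than a relu op.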

        # 2. 3x3 DepthwiseConv2d + ReLU6
self.depthwise_conv2d = torch.nn.Conv2d(
in_channels=384,
out_channels=384,
kernel_size=3,
padding=1,
stride=1,
groups=384,
) ## (1, 384, H, W)

# 3. Linear 1x1 Conv2d
self.pointwise_conv2d_linear = torch.nn.Conv2d(
in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1
) ## (1, 64, 81, 81)

def get_inputs(self) -> Tuple[torch.Tensor]:
return (torch.randn(1, 64, 81, 81),)

def forward(self, x):
input = x
# 1x1 CONV2d + ReLU6 (Pointwise)
x = self.pointwise_conv2d(x)
x = self.batch_norm2d_16(x)
x = self.relu6(x)

        # 3x3 DepthwiseConv2d + ReLU6
x = self.depthwise_conv2d(x)
x = self.batch_norm2d_16(x)
x = self.relu6(x)

# Linear 1x1 Conv2d
x = self.pointwise_conv2d_linear(x)

# Final Residual Connection
x = x + input

return x


class ComboConv2dMeandim(torch.nn.Module):
edge_op_list = [
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_mean_dim",
]

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=10, kernel_size=5, stride=1, bias=False
)
# will be specialized to aten.mean.dim
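        # (averaging over the full spatial extent with output size (1, 1) is
        # equivalent to a mean over the H/W dims)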
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

def get_inputs(self) -> Tuple[torch.Tensor]:
return (torch.randn(1, 3, 128, 128),)

def forward(self, x):
x = self.conv2d(x)
return self.adaptive_avg_pool2d(x)


class ComboConvBatchnormRelu(torch.nn.Module):
edge_op_list = [
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
]

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
)
self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
self.relu6 = torch.nn.ReLU6()

def get_inputs(self) -> Tuple[torch.Tensor]:
return (torch.randn(1, 3, 256, 256),)

def forward(self, x):
x = self.conv2d(x)
x = self.batch_norm2d(x)
x = self.relu6(x)
return x


class TestConvCombos(unittest.TestCase):
def _test_conv_combo_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
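        # Export to edge, partition for the TOSA backend (MI profile), and check
        # that all ops in module.edge_op_list were absorbed into one delegate.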
tester = (
ArmTester(
module,
inputs=test_data,
profile=TosaProfile.MI,
backend=ArmBackendSelector.TOSA,
permute_memory_to_nhwc=True,
)
.export()
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(list(module.edge_op_list))
.to_executorch()
)
if common.TOSA_REF_MODEL_INSTALLED:
tester.run_method().compare_outputs()
else:
logger.warning(
"TOSA ref model tool not installed, skip numerical correctness tests"
)

def _test_conv_combo_tosa_BI_pipeline(
self,
module: torch.nn.Module,
test_data: Tuple[torch.Tensor],
atol: float = 1e-3,
rtol: float = 1e-3,
):
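        # Same pipeline as MI, but quantize first; outputs are then compared
        # against the eager-mode reference within the given tolerances.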
tester = (
ArmTester(
module,
inputs=test_data,
profile=TosaProfile.BI,
backend=ArmBackendSelector.TOSA,
permute_memory_to_nhwc=True,
)
.quantize()
.export()
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(list(module.edge_op_list))
.to_executorch()
)
if common.TOSA_REF_MODEL_INSTALLED:
tester.run_method().compare_outputs(atol=atol, rtol=rtol, qtol=1)
else:
logger.warning(
"TOSA ref model tool not installed, skip numerical correctness tests"
)

def _test_conv_combo_u55_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
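        # Compile for Ethos-U55 (requires Vela); this only verifies that the
        # graph is fully delegated, as there is no host-side execution here.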
(
ArmTester(
module,
inputs=test_data,
profile=TosaProfile.BI,
backend=ArmBackendSelector.ETHOS_U55,
permute_memory_to_nhwc=True,
)
.quantize()
.export()
.to_edge()
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(list(module.edge_op_list))
.to_executorch()
)

####################
## Conv + meandim ##
####################
def test_conv_meandim_tosa_MI(self):
model = ComboConv2dMeandim()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

def test_conv_meandim_tosa_BI(self):
model = ComboConv2dMeandim()
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

@unittest.skipIf(
not common.VELA_INSTALLED,
"There is no point in running U55 tests if the Vela tool is not installed",
)
def test_conv_meandim_u55_BI(self):
model = ComboConv2dMeandim()
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

##############################
## Conv + batch norm + relu ##
##############################
def test_conv_batchnorm_relu_tosa_MI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

    # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in the
    # legacy testcase as well (where it is likewise untested). For now, just
    # increase the tolerance so that we don't skip the test entirely (i.e. we
    # maintain functionality).
def test_conv_batchnorm_relu_tosa_BI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_tosa_BI_pipeline(
model, model.get_inputs(), atol=1.0, rtol=1.0
)

@unittest.skipIf(
not common.VELA_INSTALLED,
"There is no point in running U55 tests if the Vela tool is not installed",
)
def test_conv_batchnorm_relu_u55_BI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

###############################
## Block bottleneck residual ##
###############################
def test_block_bottleneck_residual_tosa_MI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

    # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in the
    # legacy testcase as well. For now, just increase the tolerance so that we
    # don't skip the test entirely (i.e. we maintain functionality).
def test_block_bottleneck_residual_tosa_BI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_BI_pipeline(
model, model.get_inputs(), atol=1.0, rtol=1.0
)

@unittest.skipIf(
not common.VELA_INSTALLED,
"There is no point in running U55 tests if the Vela tool is not installed",
)
def test_block_bottleneck_residual_u55_BI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_depthwise_conv.py
@@ -19,8 +19,6 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-torch.manual_seed(42)
-
 """
 The configuration when
 groups == in_channels and
1 change: 0 additions & 1 deletion backends/arm/test/ops/test_linear.py
@@ -19,7 +19,6 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-torch.manual_seed(42)
 
 test_data_suite_rank1 = [
     # (test_name, test_data, out_features)