Skip to content

Commit b607800

Browse files
Arm backend: Enable numerical testing of quantized rshift (#8371)
Enable numerical testing of quantized rshift Signed-off-by: Oscar Andersson <[email protected]>
1 parent 0740a11 commit b607800

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

backends/arm/test/ops/test_rshift.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -7,22 +7,20 @@
77
import unittest
88

99
import torch
10-
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test import common, conftest
1111
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1212
from parameterized import parameterized
1313

1414

1515
class TestRshift(unittest.TestCase):
16-
"""
17-
Tests arithmetic right shift
18-
"""
16+
"""Tests arithmetic right shift"""
1917

2018
class Rshift(torch.nn.Module):
2119
test_data = [
2220
((torch.IntTensor(5, 5), 2),),
2321
((torch.IntTensor(1, 2, 3, 4), 3),),
22+
((torch.CharTensor(1, 12, 3, 4), 1),),
2423
((torch.ShortTensor(1, 5, 3, 4), 5),),
25-
((torch.CharTensor(10, 12, 3, 4), 1),),
2624
]
2725

2826
def forward(self, x: torch.Tensor, shift: int):
@@ -52,8 +50,7 @@ def _test_rshift_tosa_BI(self, test_data):
5250
.export()
5351
.to_edge_transform_and_lower()
5452
.to_executorch()
55-
# TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
56-
# .run_method_and_compare_outputs(inputs=test_data)
53+
.run_method_and_compare_outputs(inputs=test_data)
5754
)
5855

5956
def _test_rshift_ethosu_BI(self, test_data, compile_spec):
@@ -67,6 +64,7 @@ def _test_rshift_ethosu_BI(self, test_data, compile_spec):
6764
.export()
6865
.to_edge_transform_and_lower()
6966
.to_executorch()
67+
.serialize()
7068
)
7169

7270
@parameterized.expand(Rshift.test_data)
@@ -77,14 +75,18 @@ def test_rshift_tosa_MI(self, test_data):
7775
def test_rshift_tosa_BI(self, test_data):
7876
self._test_rshift_tosa_BI(test_data)
7977

80-
# TODO Enable FVP testing
81-
@parameterized.expand(Rshift.test_data)
78+
# TODO: MLETORCH-644 - Add support for INT16 input/output
79+
@parameterized.expand(Rshift.test_data[:-1])
8280
def test_rshift_u55_BI(self, test_data):
8381
compile_spec = common.get_u55_compile_spec()
84-
self._test_rshift_ethosu_BI(test_data, compile_spec)
82+
tester = self._test_rshift_ethosu_BI(test_data, compile_spec)
83+
if conftest.is_option_enabled("corstone_fvp"):
84+
tester.run_method_and_compare_outputs(atol=1, inputs=test_data)
8585

86-
# TODO Enable FVP testing
87-
@parameterized.expand(Rshift.test_data)
86+
# TODO: MLETORCH-644 - Add support for INT16 input/output
87+
@parameterized.expand(Rshift.test_data[:-1])
8888
def test_rshift_u85_BI(self, test_data):
8989
compile_spec = common.get_u85_compile_spec()
90-
self._test_rshift_ethosu_BI(test_data, compile_spec)
90+
tester = self._test_rshift_ethosu_BI(test_data, compile_spec)
91+
if conftest.is_option_enabled("corstone_fvp"):
92+
tester.run_method_and_compare_outputs(inputs=test_data)

0 commit comments

Comments
 (0)