1
- # Copyright 2024 Arm Limited and/or its affiliates.
1
+ # Copyright 2024-2025 Arm Limited and/or its affiliates.
2
2
# All rights reserved.
3
3
#
4
4
# This source code is licensed under the BSD-style license found in the
7
7
import unittest
8
8
9
9
import torch
10
- from executorch .backends .arm .test import common
10
+ from executorch .backends .arm .test import common , conftest
11
11
from executorch .backends .arm .test .tester .arm_tester import ArmTester
12
12
from parameterized import parameterized
13
13
14
14
15
15
class TestRshift (unittest .TestCase ):
16
- """
17
- Tests arithmetic right shift
18
- """
16
+ """Tests arithmetic right shift"""
19
17
20
18
class Rshift (torch .nn .Module ):
21
19
test_data = [
22
20
((torch .IntTensor (5 , 5 ), 2 ),),
23
21
((torch .IntTensor (1 , 2 , 3 , 4 ), 3 ),),
22
+ ((torch .CharTensor (1 , 12 , 3 , 4 ), 1 ),),
24
23
((torch .ShortTensor (1 , 5 , 3 , 4 ), 5 ),),
25
- ((torch .CharTensor (10 , 12 , 3 , 4 ), 1 ),),
26
24
]
27
25
28
26
def forward (self , x : torch .Tensor , shift : int ):
@@ -52,8 +50,7 @@ def _test_rshift_tosa_BI(self, test_data):
52
50
.export ()
53
51
.to_edge_transform_and_lower ()
54
52
.to_executorch ()
55
- # TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
56
- # .run_method_and_compare_outputs(inputs=test_data)
53
+ .run_method_and_compare_outputs (inputs = test_data )
57
54
)
58
55
59
56
def _test_rshift_ethosu_BI (self , test_data , compile_spec ):
@@ -67,6 +64,7 @@ def _test_rshift_ethosu_BI(self, test_data, compile_spec):
67
64
.export ()
68
65
.to_edge_transform_and_lower ()
69
66
.to_executorch ()
67
+ .serialize ()
70
68
)
71
69
72
70
@parameterized .expand (Rshift .test_data )
@@ -77,14 +75,18 @@ def test_rshift_tosa_MI(self, test_data):
77
75
def test_rshift_tosa_BI (self , test_data ):
78
76
self ._test_rshift_tosa_BI (test_data )
79
77
80
- # TODO Enable FVP testing
81
- @parameterized .expand (Rshift .test_data )
78
+ # TODO: MLETORCH-644 - Add support for INT16 input/output
79
+ @parameterized .expand (Rshift .test_data [: - 1 ] )
82
80
def test_rshift_u55_BI (self , test_data ):
83
81
compile_spec = common .get_u55_compile_spec ()
84
- self ._test_rshift_ethosu_BI (test_data , compile_spec )
82
+ tester = self ._test_rshift_ethosu_BI (test_data , compile_spec )
83
+ if conftest .is_option_enabled ("corstone_fvp" ):
84
+ tester .run_method_and_compare_outputs (atol = 1 , inputs = test_data )
85
85
86
- # TODO Enable FVP testing
87
- @parameterized .expand (Rshift .test_data )
86
+ # TODO: MLETORCH-644 - Add support for INT16 input/output
87
+ @parameterized .expand (Rshift .test_data [: - 1 ] )
88
88
def test_rshift_u85_BI (self , test_data ):
89
89
compile_spec = common .get_u85_compile_spec ()
90
- self ._test_rshift_ethosu_BI (test_data , compile_spec )
90
+ tester = self ._test_rshift_ethosu_BI (test_data , compile_spec )
91
+ if conftest .is_option_enabled ("corstone_fvp" ):
92
+ tester .run_method_and_compare_outputs (inputs = test_data )
0 commit comments