9
9
import shutil
10
10
import unittest
11
11
12
- from typing import Optional , Tuple
12
+ from typing import Tuple
13
13
14
14
import torch
15
15
from executorch .backends .arm .test .test_models import TosaProfile
30
30
31
31
class TestSimpleAdd (unittest .TestCase ):
32
32
class Add (torch .nn .Module ):
33
+ test_parameters = [
34
+ (torch .ones (5 ),),
35
+ (3 * torch .ones (8 ),),
36
+ (10 * torch .randn (8 ),),
37
+ ]
38
+
33
39
def __init__ (self ):
34
40
super ().__init__ ()
35
41
self .permute_memory_to_nhwc = False
@@ -38,6 +44,13 @@ def forward(self, x):
38
44
return x + x
39
45
40
46
class Add2 (torch .nn .Module ):
47
+ test_parameters = [
48
+ (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 4 )),
49
+ (torch .randn (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 )),
50
+ (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 )),
51
+ (10000 * torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 )),
52
+ ]
53
+
41
54
def __init__ (self ):
42
55
super ().__init__ ()
43
56
self .permute_memory_to_nhwc = False
@@ -118,40 +131,40 @@ def _test_add_u55_BI_pipeline(
118
131
.to_executorch ()
119
132
)
120
133
121
- def test_add_tosa_MI (self ):
122
- test_data = (torch .randn (4 , 4 , 4 ),)
134
+ @parameterized .expand (Add .test_parameters )
135
+ def test_add_tosa_MI (self , test_data : torch .Tensor ):
136
+ test_data = (test_data ,)
123
137
self ._test_add_tosa_MI_pipeline (self .Add (), test_data )
124
138
125
- @parameterized .expand (
126
- [
127
- (torch .ones (5 ),), # test_data
128
- (3 * torch .ones (8 ),),
129
- ]
130
- )
131
- def test_add_tosa_BI (self , test_data : Optional [Tuple [torch .Tensor ]]):
139
+ @parameterized .expand (Add .test_parameters )
140
+ def test_add_tosa_BI (self , test_data : torch .Tensor ):
132
141
test_data = (test_data ,)
133
142
self ._test_add_tosa_BI_pipeline (self .Add (), test_data )
134
143
144
+ @parameterized .expand (Add .test_parameters )
135
145
@unittest .skipIf (
136
146
not VELA_INSTALLED ,
137
147
"There is no point in running U55 tests if the Vela tool is not installed" ,
138
148
)
139
- def test_add_u55_BI (self ):
140
- test_data = (3 * torch . ones ( 5 ) ,)
149
+ def test_add_u55_BI (self , test_data : torch . Tensor ):
150
+ test_data = (test_data ,)
141
151
self ._test_add_u55_BI_pipeline (self .Add (), test_data )
142
152
143
- def test_add2_tosa_MI (self ):
144
- test_data = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 ))
153
+ @parameterized .expand (Add2 .test_parameters )
154
+ def test_add2_tosa_MI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
155
+ test_data = (operand1 , operand2 )
145
156
self ._test_add_tosa_MI_pipeline (self .Add2 (), test_data )
146
157
147
- def test_add2_tosa_BI (self ):
148
- test_data = (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 ))
158
+ @parameterized .expand (Add2 .test_parameters )
159
+ def test_add2_tosa_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
160
+ test_data = (operand1 , operand2 )
149
161
self ._test_add_tosa_BI_pipeline (self .Add2 (), test_data )
150
162
163
+ @parameterized .expand (Add2 .test_parameters )
151
164
@unittest .skipIf (
152
165
not VELA_INSTALLED ,
153
166
"There is no point in running U55 tests if the Vela tool is not installed" ,
154
167
)
155
- def test_add2_u55_BI (self ):
156
- test_data = (torch . ones ( 1 , 1 , 4 , 4 ), torch . ones ( 1 , 1 , 4 , 1 ) )
168
+ def test_add2_u55_BI (self , operand1 : torch . Tensor , operand2 : torch . Tensor ):
169
+ test_data = (operand1 , operand2 )
157
170
self ._test_add_u55_BI_pipeline (self .Add2 (), test_data )
0 commit comments