9
9
import shutil
10
10
import unittest
11
11
12
- from typing import Optional , Tuple
12
+ from typing import Tuple
13
13
14
14
import torch
15
15
from executorch .backends .arm .test .test_models import TosaProfile
30
30
31
31
class TestSimpleAdd (unittest .TestCase ):
32
32
class Add (torch .nn .Module ):
33
+ test_parameters = [
34
+ (torch .ones (5 ),),
35
+ (3 * torch .ones (8 ),),
36
+ (10 * torch .randn (8 ),),
37
+ ]
38
+
33
39
def __init__ (self ):
34
40
super ().__init__ ()
35
41
self .permute_memory_to_nhwc = False
@@ -38,6 +44,12 @@ def forward(self, x):
38
44
return x + x
39
45
40
46
class Add2 (torch .nn .Module ):
47
+ test_parameters = [
48
+ (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 4 )),
49
+ (torch .randn (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 )),
50
+ (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 )),
51
+ ]
52
+
41
53
def __init__ (self ):
42
54
super ().__init__ ()
43
55
self .permute_memory_to_nhwc = False
@@ -118,40 +130,40 @@ def _test_add_u55_BI_pipeline(
118
130
.to_executorch ()
119
131
)
120
132
121
- def test_add_tosa_MI (self ):
122
- test_data = (torch .randn (4 , 4 , 4 ),)
133
+ @parameterized .expand (Add .test_parameters )
134
+ def test_add_tosa_MI (self , test_data : torch .Tensor ):
135
+ test_data = (test_data ,)
123
136
self ._test_add_tosa_MI_pipeline (self .Add (), test_data )
124
137
125
- @parameterized .expand (
126
- [
127
- (torch .ones (5 ),), # test_data
128
- (3 * torch .ones (8 ),),
129
- ]
130
- )
131
- def test_add_tosa_BI (self , test_data : Optional [Tuple [torch .Tensor ]]):
138
+ @parameterized .expand (Add .test_parameters )
139
+ def test_add_tosa_BI (self , test_data : torch .Tensor ):
132
140
test_data = (test_data ,)
133
141
self ._test_add_tosa_BI_pipeline (self .Add (), test_data )
134
142
135
143
@unittest .skipIf (
136
144
not VELA_INSTALLED ,
137
145
"There is no point in running U55 tests if the Vela tool is not installed" ,
138
146
)
139
- def test_add_u55_BI (self ):
140
- test_data = (3 * torch .ones (5 ),)
147
+ @parameterized .expand (Add .test_parameters )
148
+ def test_add_u55_BI (self , test_data : torch .Tensor ):
149
+ test_data = (test_data ,)
141
150
self ._test_add_u55_BI_pipeline (self .Add (), test_data )
142
151
143
- def test_add2_tosa_MI (self ):
144
- test_data = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 ))
152
+ @parameterized .expand (Add2 .test_parameters )
153
+ def test_add2_tosa_MI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
154
+ test_data = (operand1 , operand2 )
145
155
self ._test_add_tosa_MI_pipeline (self .Add2 (), test_data )
146
156
147
- def test_add2_tosa_BI (self ):
148
- test_data = (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 ))
157
+ @parameterized .expand (Add2 .test_parameters )
158
+ def test_add2_tosa_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
159
+ test_data = (operand1 , operand2 )
149
160
self ._test_add_tosa_BI_pipeline (self .Add2 (), test_data )
150
161
151
162
@unittest .skipIf (
152
163
not VELA_INSTALLED ,
153
164
"There is no point in running U55 tests if the Vela tool is not installed" ,
154
165
)
155
- def test_add2_u55_BI (self ):
156
- test_data = (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 ))
166
+ @parameterized .expand (Add2 .test_parameters )
167
+ def test_add2_u55_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
168
+ test_data = (operand1 , operand2 )
157
169
self ._test_add_u55_BI_pipeline (self .Add2 (), test_data )
0 commit comments