9
9
import shutil
10
10
import unittest
11
11
12
- from typing import Optional , Tuple
12
+ from typing import Tuple
13
13
14
14
import torch
15
15
from executorch .backends .arm .test .test_models import TosaProfile
30
30
31
31
class TestSimpleAdd (unittest .TestCase ):
32
32
class Add (torch .nn .Module ):
33
+ test_parameters = [
34
+ (torch .ones (5 ),),
35
+ (3 * torch .ones (8 ),),
36
+ (10 * torch .randn (8 ),),
37
+ ]
38
+
33
39
def __init__ (self ):
34
40
super ().__init__ ()
35
41
36
42
def forward (self , x ):
37
43
return x + x
38
44
39
45
class Add2 (torch .nn .Module ):
46
+ test_parameters = [
47
+ (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 4 )),
48
+ (torch .randn (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 )),
49
+ (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 )),
50
+ ]
51
+
40
52
def __init__ (self ):
41
53
super ().__init__ ()
42
54
@@ -88,7 +100,7 @@ def _test_add_tosa_BI_pipeline(
88
100
.to_executorch ()
89
101
)
90
102
if TOSA_REF_MODEL_INSTALLED :
91
- tester .run_method ().compare_outputs ()
103
+ tester .run_method ().compare_outputs (qtol = 1 )
92
104
else :
93
105
logger .warning (
94
106
"TOSA ref model tool not installed, skip numerical correctness tests"
@@ -114,42 +126,42 @@ def _test_add_u55_BI_pipeline(
114
126
.to_executorch ()
115
127
)
116
128
117
- def test_add_tosa_MI (self ):
118
- test_data = (torch .randn (4 , 4 , 4 ),)
129
+ @parameterized .expand (Add .test_parameters )
130
+ def test_add_tosa_MI (self , test_data : torch .Tensor ):
131
+ test_data = (test_data ,)
119
132
self ._test_add_tosa_MI_pipeline (self .Add (), test_data )
120
133
121
134
# TODO: Will this type of parametrization be supported? pytest seem
122
135
# have issue with it.
123
- @parameterized .expand (
124
- [
125
- (torch .ones (5 ),), # test_data
126
- (3 * torch .ones (8 ),),
127
- ]
128
- )
129
- def test_add_tosa_BI (self , test_data : Optional [Tuple [torch .Tensor ]]):
136
+ @parameterized .expand (Add .test_parameters )
137
+ def test_add_tosa_BI (self , test_data : torch .Tensor ):
130
138
test_data = (test_data ,)
131
139
self ._test_add_tosa_BI_pipeline (self .Add (), test_data )
132
140
133
141
@unittest .skipIf (
134
142
not VELA_INSTALLED ,
135
143
"There is no point in running U55 tests if the Vela tool is not installed" ,
136
144
)
137
- def test_add_u55_BI (self ):
138
- test_data = (3 * torch .ones (5 ),)
145
+ @parameterized .expand (Add .test_parameters )
146
+ def test_add_u55_BI (self , test_data : torch .Tensor ):
147
+ test_data = (test_data ,)
139
148
self ._test_add_u55_BI_pipeline (self .Add (), test_data )
140
149
141
- def test_add2_tosa_MI (self ):
142
- test_data = (torch .randn (1 , 1 , 4 , 4 ), torch .randn (1 , 1 , 4 , 1 ))
150
+ @parameterized .expand (Add2 .test_parameters )
151
+ def test_add2_tosa_MI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
152
+ test_data = (operand1 , operand2 )
143
153
self ._test_add_tosa_MI_pipeline (self .Add2 (), test_data )
144
154
145
- def test_add2_tosa_BI (self ):
146
- test_data = (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 ))
155
+ @parameterized .expand (Add2 .test_parameters )
156
+ def test_add2_tosa_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
157
+ test_data = (operand1 , operand2 )
147
158
self ._test_add_tosa_BI_pipeline (self .Add2 (), test_data )
148
159
149
160
@unittest .skipIf (
150
161
not VELA_INSTALLED ,
151
162
"There is no point in running U55 tests if the Vela tool is not installed" ,
152
163
)
153
- def test_add2_u55_BI (self ):
154
- test_data = (torch .ones (1 , 1 , 4 , 4 ), torch .ones (1 , 1 , 4 , 1 ))
164
+ @parameterized .expand (Add2 .test_parameters )
165
+ def test_add2_u55_BI (self , operand1 : torch .Tensor , operand2 : torch .Tensor ):
166
+ test_data = (operand1 , operand2 )
155
167
self ._test_add_u55_BI_pipeline (self .Add2 (), test_data )
0 commit comments