|
9 | 9 | import shutil
|
10 | 10 | import unittest
|
11 | 11 |
|
12 |
| -from typing import Optional, Tuple |
| 12 | +from typing import Tuple |
13 | 13 |
|
14 | 14 | import torch
|
15 | 15 | from executorch.backends.arm.test.test_models import TosaProfile
|
|
30 | 30 |
|
31 | 31 | class TestSimpleAdd(unittest.TestCase):
|
32 | 32 | class Add(torch.nn.Module):
|
| 33 | + test_parameters = [ |
| 34 | + (torch.ones(5),), |
| 35 | + (3 * torch.ones(8),), |
| 36 | + (10 * torch.randn(8),), |
| 37 | + ] |
| 38 | + |
33 | 39 | def __init__(self):
|
34 | 40 | super().__init__()
|
35 | 41 |
|
36 | 42 | def forward(self, x):
|
37 | 43 | return x + x
|
38 | 44 |
|
39 | 45 | class Add2(torch.nn.Module):
|
| 46 | + test_parameters = [ |
| 47 | + (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 4)), |
| 48 | + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), |
| 49 | + (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), |
| 50 | + ] |
| 51 | + |
40 | 52 | def __init__(self):
|
41 | 53 | super().__init__()
|
42 | 54 |
|
@@ -114,40 +126,40 @@ def _test_add_u55_BI_pipeline(
|
114 | 126 | .to_executorch()
|
115 | 127 | )
|
116 | 128 |
|
117 |
| - def test_add_tosa_MI(self): |
118 |
| - test_data = (torch.randn(4, 4, 4),) |
| 129 | + @parameterized.expand(Add.test_parameters) |
| 130 | + def test_add_tosa_MI(self, test_data: torch.Tensor): |
| 131 | + test_data = (test_data,) |
119 | 132 | self._test_add_tosa_MI_pipeline(self.Add(), test_data)
|
120 | 133 |
|
121 |
| - @parameterized.expand( |
122 |
| - [ |
123 |
| - (torch.ones(5),), # test_data |
124 |
| - (3 * torch.ones(8),), |
125 |
| - ] |
126 |
| - ) |
127 |
| - def test_add_tosa_BI(self, test_data: Optional[Tuple[torch.Tensor]]): |
| 134 | + @parameterized.expand(Add.test_parameters) |
| 135 | + def test_add_tosa_BI(self, test_data: torch.Tensor): |
128 | 136 | test_data = (test_data,)
|
129 | 137 | self._test_add_tosa_BI_pipeline(self.Add(), test_data)
|
130 | 138 |
|
131 | 139 | @unittest.skipIf(
|
132 | 140 | not VELA_INSTALLED,
|
133 | 141 | "There is no point in running U55 tests if the Vela tool is not installed",
|
134 | 142 | )
|
135 |
| - def test_add_u55_BI(self): |
136 |
| - test_data = (3 * torch.ones(5),) |
| 143 | + @parameterized.expand(Add.test_parameters) |
| 144 | + def test_add_u55_BI(self, test_data: torch.Tensor): |
| 145 | + test_data = (test_data,) |
137 | 146 | self._test_add_u55_BI_pipeline(self.Add(), test_data)
|
138 | 147 |
|
139 |
| - def test_add2_tosa_MI(self): |
140 |
| - test_data = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)) |
| 148 | + @parameterized.expand(Add2.test_parameters) |
| 149 | + def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
| 150 | + test_data = (operand1, operand2) |
141 | 151 | self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
|
142 | 152 |
|
143 |
| - def test_add2_tosa_BI(self): |
144 |
| - test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1)) |
| 153 | + @parameterized.expand(Add2.test_parameters) |
| 154 | + def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
| 155 | + test_data = (operand1, operand2) |
145 | 156 | self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
|
146 | 157 |
|
147 | 158 | @unittest.skipIf(
|
148 | 159 | not VELA_INSTALLED,
|
149 | 160 | "There is no point in running U55 tests if the Vela tool is not installed",
|
150 | 161 | )
|
151 |
| - def test_add2_u55_BI(self): |
152 |
| - test_data = (torch.ones(1, 1, 4, 4), torch.ones(1, 1, 4, 1)) |
| 162 | + @parameterized.expand(Add2.test_parameters) |
| 163 | + def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): |
| 164 | + test_data = (operand1, operand2) |
153 | 165 | self._test_add_u55_BI_pipeline(self.Add2(), test_data)
|
0 commit comments