6
6
7
7
import logging
8
8
import os
9
+ import shutil
9
10
import tempfile
10
11
import unittest
11
12
@@ -126,8 +127,62 @@ def test_numerical_diff_prints(self):
126
127
self .fail ()
127
128
128
129
129
- class TestDumpOperatorsAndDtypes (unittest .TestCase ):
130
- def test_dump_ops_and_dtypes (self ):
130
+ def test_dump_ops_and_dtypes ():
131
+ model = Linear (20 , 30 )
132
+ (
133
+ ArmTester (
134
+ model ,
135
+ example_inputs = model .get_inputs (),
136
+ compile_spec = common .get_tosa_compile_spec (),
137
+ )
138
+ .quantize ()
139
+ .dump_dtype_distribution ()
140
+ .dump_operator_distribution ()
141
+ .export ()
142
+ .dump_dtype_distribution ()
143
+ .dump_operator_distribution ()
144
+ .to_edge ()
145
+ .dump_dtype_distribution ()
146
+ .dump_operator_distribution ()
147
+ .partition ()
148
+ .dump_dtype_distribution ()
149
+ .dump_operator_distribution ()
150
+ )
151
+ # Just test that there are no execptions.
152
+
153
+
154
+ def test_dump_ops_and_dtypes_parseable ():
155
+ model = Linear (20 , 30 )
156
+ (
157
+ ArmTester (
158
+ model ,
159
+ example_inputs = model .get_inputs (),
160
+ compile_spec = common .get_tosa_compile_spec (),
161
+ )
162
+ .quantize ()
163
+ .dump_dtype_distribution (print_table = False )
164
+ .dump_operator_distribution (print_table = False )
165
+ .export ()
166
+ .dump_dtype_distribution (print_table = False )
167
+ .dump_operator_distribution (print_table = False )
168
+ .to_edge ()
169
+ .dump_dtype_distribution (print_table = False )
170
+ .dump_operator_distribution (print_table = False )
171
+ .partition ()
172
+ .dump_dtype_distribution (print_table = False )
173
+ .dump_operator_distribution (print_table = False )
174
+ )
175
+ # Just test that there are no execptions.
176
+
177
+
178
+ class TestCollateTosaTests (unittest .TestCase ):
179
+ """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH."""
180
+
181
+ def test_collate_tosa_BI_tests (self ):
182
+ # Set the environment variable to trigger the collation of TOSA tests
183
+ os .environ ["TOSA_TESTCASES_BASE_PATH" ] = "test_collate_tosa_tests"
184
+ # Clear out the directory
185
+
131
186
model = Linear (20 , 30 )
132
187
(
133
188
ArmTester (
@@ -136,16 +191,59 @@ def test_dump_ops_and_dtypes(self):
136
191
compile_spec = common .get_tosa_compile_spec (),
137
192
)
138
193
.quantize ()
139
- .dump_dtype_distribution ()
140
- .dump_operator_distribution ()
141
194
.export ()
142
- .dump_dtype_distribution ()
143
- .dump_operator_distribution ()
144
195
.to_edge ()
145
- .dump_dtype_distribution ()
146
- .dump_operator_distribution ()
147
196
.partition ()
148
- .dump_dtype_distribution ()
149
- .dump_operator_distribution ()
197
+ .to_executorch ()
198
+ )
199
+ # test that the output directory is created and contains the expected files
200
+ assert os .path .exists (
201
+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
202
+ )
203
+ assert os .path .exists (
204
+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa"
205
+ )
206
+ assert os .path .exists (
207
+ "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json"
208
+ )
209
+
210
+ os .environ .pop ("TOSA_TESTCASES_BASE_PATH" )
211
+ shutil .rmtree ("test_collate_tosa_tests" , ignore_errors = True )
212
+
213
+
214
+ def test_dump_tosa_ops (caplog ):
215
+ caplog .set_level (logging .INFO )
216
+ model = Linear (20 , 30 )
217
+ (
218
+ ArmTester (
219
+ model ,
220
+ example_inputs = model .get_inputs (),
221
+ compile_spec = common .get_tosa_compile_spec (),
150
222
)
151
- # Just test that there are no execeptions.
223
+ .quantize ()
224
+ .export ()
225
+ .to_edge ()
226
+ .partition ()
227
+ .dump_operator_distribution ()
228
+ )
229
+ assert "TOSA operators:" in caplog .text
230
+
231
+
232
+ def test_fail_dump_tosa_ops (caplog ):
233
+ caplog .set_level (logging .INFO )
234
+
235
+ class Add (torch .nn .Module ):
236
+ def forward (self , x ):
237
+ return x + x
238
+
239
+ model = Add ()
240
+ compile_spec = common .get_u55_compile_spec ()
241
+ (
242
+ ArmTester (model , example_inputs = (torch .ones (5 ),), compile_spec = compile_spec )
243
+ .quantize ()
244
+ .export ()
245
+ .to_edge ()
246
+ .partition ()
247
+ .dump_operator_distribution ()
248
+ )
249
+ assert "Can not get operator distribution for Vela command stream." in caplog .text
0 commit comments