Skip to content

Commit d516309

Browse files
perfacebook-github-bot
authored andcommitted
Add possibility to collect all TOSA tests to a specified path (#5028)
Summary: Done in order to collect test vectors for backend compilers. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I0fc6e4d6bfcccd6aae18847a9a33f76d3d19fe5f Pull Request resolved: #5028 Reviewed By: cccclai Differential Revision: D62242846 Pulled By: digantdesai fbshipit-source-id: 9ecfb7be3c5ed432a2cc36c2ea1eac7157ef6673
1 parent 341545c commit d516309

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

backends/arm/test/common.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,29 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
8686
return False
8787

8888

89+
def maybe_get_tosa_collate_path() -> str | None:
90+
"""
91+
Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
92+
path to the where to store the current tests if it is set.
93+
"""
94+
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
95+
if tosa_test_base:
96+
current_test = os.environ.get("PYTEST_CURRENT_TEST")
97+
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
98+
test_class = current_test.split("::")[1]
99+
test_name = current_test.split("::")[-1].split(" ")[0]
100+
if "BI" in test_name:
101+
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
102+
elif "MI" in test_name:
103+
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
104+
else:
105+
tosa_test_base = os.path.join(tosa_test_base, "other")
106+
107+
return os.path.join(tosa_test_base, test_class, test_name)
108+
109+
return None
110+
111+
89112
def get_tosa_compile_spec(
90113
permute_memory_to_nhwc=True, custom_path=None
91114
) -> list[CompileSpec]:
@@ -101,7 +124,13 @@ def get_tosa_compile_spec_unbuilt(
101124
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
102125
the compile spec before calling .build() to finalize it.
103126
"""
104-
intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_")
127+
if not custom_path:
128+
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
129+
prefix="arm_tosa_"
130+
)
131+
else:
132+
intermediate_path = custom_path
133+
105134
if not os.path.exists(intermediate_path):
106135
os.makedirs(intermediate_path, exist_ok=True)
107136
compile_spec_builder = (

backends/arm/test/misc/test_debug_feats.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import logging
88
import os
9+
import shutil
910
import tempfile
1011
import unittest
1112

@@ -149,3 +150,39 @@ def test_dump_ops_and_dtypes(self):
149150
.dump_operator_distribution()
150151
)
151152
# Just test that there are no execeptions.
153+
154+
155+
class TestCollateTosaTests(unittest.TestCase):
156+
"""Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH."""
157+
158+
def test_collate_tosa_BI_tests(self):
159+
# Set the environment variable to trigger the collation of TOSA tests
160+
os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests"
161+
# Clear out the directory
162+
163+
model = Linear(20, 30)
164+
(
165+
ArmTester(
166+
model,
167+
example_inputs=model.get_inputs(),
168+
compile_spec=common.get_tosa_compile_spec(),
169+
)
170+
.quantize()
171+
.export()
172+
.to_edge()
173+
.partition()
174+
.to_executorch()
175+
)
176+
# test that the output directory is created and contains the expected files
177+
assert os.path.exists(
178+
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
179+
)
180+
assert os.path.exists(
181+
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag8.tosa"
182+
)
183+
assert os.path.exists(
184+
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag8.json"
185+
)
186+
187+
os.environ.pop("TOSA_TESTCASES_BASE_PATH")
188+
shutil.rmtree("test_collate_tosa_tests", ignore_errors=True)

0 commit comments

Comments
 (0)