Skip to content

Commit 187511f

Browse files
committed
[ExecuTorch] Arm Ethos: Add pass tests
As title. Adds pytest.test_option["tosa_ref_model"] similar to "corestone_fvp". This is a hack. Once we buckify the reference model, we should remove this. It shouldn't have impact on the OSS test coverage. Differential Revision: [D69714010](https://our.internmc.facebook.com/intern/diff/D69714010/) ghstack-source-id: 266705920 Pull Request resolved: #8561
1 parent bb2d6b5 commit 187511f

File tree

6 files changed

+80
-13
lines changed

6 files changed

+80
-13
lines changed

backends/arm/test/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
load(":targets.bzl", "define_arm_tests")
3+
4+
5+
oncall("executorch")
26

37
python_library(
48
name = "common",
@@ -8,6 +12,7 @@ python_library(
812
"//executorch/backends/arm:arm_backend",
913
"//executorch/exir:lib",
1014
"//executorch/exir/backend:compile_spec_schema",
15+
"fbsource//third-party/pypi/pytest:pytest",
1116
]
1217
)
1318

@@ -40,7 +45,10 @@ python_library(
4045
"//executorch/backends/arm:tosa_mapping",
4146
"//executorch/backends/arm:tosa_specification",
4247
"//executorch/backends/arm/quantizer:arm_quantizer",
48+
"//executorch/backends/arm:arm_partitioner",
4349
"//executorch/devtools/backend_debug:delegation_info",
4450
"fbsource//third-party/pypi/tabulate:tabulate",
4551
]
4652
)
53+
54+
define_arm_tests()

backends/arm/test/conftest.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from typing import Any
1414

1515
import pytest
16-
import torch
16+
import logging
17+
18+
try:
19+
import tosa_reference_model
20+
except ImportError:
21+
logging.warning("tosa_reference_model not found, can't run reference model tests")
22+
tosa_reference_model = None
1723

1824
"""
1925
This file contains the pytest hooks, fixtures etc. for the Arm test suite.
@@ -24,18 +30,29 @@
2430

2531

2632
def pytest_configure(config):
27-
2833
pytest._test_options = {} # type: ignore[attr-defined]
29-
30-
if config.option.arm_run_corstoneFVP:
34+
pytest._test_options["corstone_fvp"] = False # type: ignore[attr-defined]
35+
if (
36+
getattr(config.option, "arm_run_corestoneFVP", False)
37+
and config.option.arm_run_corstoneFVP
38+
):
3139
corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
3240
corstone320_exists = shutil.which("FVP_Corstone_SSE-320")
3341
if not (corstone300_exists and corstone320_exists):
3442
raise RuntimeError(
3543
"Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed."
3644
)
37-
pytest._test_options["corstone_fvp"] = True # type: ignore[attr-defined]
38-
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]
45+
# Only enable if we also have the TOSA reference model available.
46+
pytest._test_options["corstone_fvp"] = tosa_reference_model is not None # type: ignore[attr-defined]
47+
48+
pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined]
49+
if getattr(config.option, "fast_fvp", False):
50+
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]
51+
52+
# TODO: remove this flag once we have a way to run the reference model tests with Buck
53+
pytest._test_options["tosa_ref_model"] = False # type: ignore[attr-defined]
54+
if tosa_reference_model is not None:
55+
pytest._test_options["tosa_ref_model"] = True # type: ignore[attr-defined]
3956
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
4057

4158

@@ -44,9 +61,15 @@ def pytest_collection_modifyitems(config, items):
4461

4562

4663
def pytest_addoption(parser):
47-
parser.addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
48-
parser.addoption("--arm_run_corstoneFVP", action="store_true")
49-
parser.addoption("--fast_fvp", action="store_true")
64+
def try_addoption(*args, **kwargs):
65+
try:
66+
parser.addoption(*args, **kwargs)
67+
except Exception:
68+
pass
69+
70+
try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
71+
try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.")
72+
try_addoption("--fast_fvp", action="store_true")
5073

5174

5275
def pytest_sessionstart(session):
@@ -78,6 +101,8 @@ def set_random_seed():
78101
Rerun with a specific seed found under a random seed test
79102
ARM_TEST_SEED=3478246 pytest --config-file=/dev/null --verbose -s --color=yes backends/arm/test/ops/test_avg_pool.py -k <TESTCASE>
80103
"""
104+
import torch
105+
81106
if os.environ.get("ARM_TEST_SEED", "RANDOM") == "RANDOM":
82107
random.seed() # reset seed, in case any other test has fiddled with it
83108
seed = random.randint(0, 2**32 - 1)
@@ -161,6 +186,8 @@ def _load_libquantized_ops_aot_lib():
161186
res = subprocess.run(find_lib_cmd, capture_output=True)
162187
if res.returncode == 0:
163188
library_path = res.stdout.decode().strip()
189+
import torch
190+
164191
torch.ops.load_library(library_path)
165192
else:
166193
raise RuntimeError(

backends/arm/test/passes/test_rescale_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _test_rescale_pipeline(
116116
):
117117
"""Tests a model with many ops that requires rescales. As more ops are quantized to int32 and
118118
need the InsertRescalesPass, make sure that they play nicely together."""
119-
(
119+
tester = (
120120
ArmTester(
121121
module,
122122
example_inputs=test_data,
@@ -126,8 +126,9 @@ def _test_rescale_pipeline(
126126
.export()
127127
.to_edge_transform_and_lower()
128128
.to_executorch()
129-
.run_method_and_compare_outputs(test_data)
130129
)
130+
if conftest.is_option_enabled("tosa_ref_model"):
131+
tester.run_method_and_compare_outputs(test_data)
131132

132133

133134
def _test_rescale_pipeline_ethosu(
@@ -152,6 +153,7 @@ def _test_rescale_pipeline_ethosu(
152153
class TestRescales(unittest.TestCase):
153154

154155
@parameterized.expand(RescaleNetwork.test_parameters)
156+
@pytest.mark.tosa_ref_model
155157
def test_quantized_rescale(self, x, y):
156158
_test_rescale_pipeline(RescaleNetwork(), (x, y))
157159

backends/arm/test/pytest.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
addopts = --strict-markers
33
markers =
44
slow: Tests that take long time
5-
corstone_fvp: Tests that use Corstone300 or Corstone320 FVP
5+
corstone_fvp: Tests that use Corstone300 or Corstone320 FVP # And also uses TOSA reference model
6+
tosa_ref_model: Tests that use TOSA reference model # Temporary!

backends/arm/test/runner_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
try:
2323
import tosa_reference_model
2424
except ImportError:
25-
logger.warning("tosa_reference_model not found, can't run reference model tests")
2625
tosa_reference_model = None
2726
from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa
2827

backends/arm/test/targets.bzl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
load("//caffe2/test/fb:defs.bzl", "define_tests")
2+
load("@bazel_skylib//lib:paths.bzl", "paths")
3+
4+
def define_arm_tests():
5+
# TODO Add more tests
6+
test_files = native.glob(["passes/test_*.py"])
7+
8+
TESTS = {}
9+
10+
for test_file in test_files:
11+
test_file_name = paths.basename(test_file)
12+
test_name = test_file_name.replace("test_", "").replace(".py", "")
13+
TESTS[test_name] = [test_file]
14+
15+
define_tests(
16+
pytest = True,
17+
tests = TESTS,
18+
pytest_config = "pytest.ini",
19+
resources = ["conftest.py"],
20+
preload_deps = [
21+
"//executorch/kernels/quantized:custom_ops_generated_lib",
22+
],
23+
deps = [
24+
":arm_tester",
25+
":conftest",
26+
"//executorch/exir:lib",
27+
"fbsource//third-party/pypi/pytest:pytest",
28+
"fbsource//third-party/pypi/parameterized:parameterized",
29+
],
30+
)

0 commit comments

Comments
 (0)