Skip to content

Commit 0121dae

Browse files
Arm backend: Update parse_test_name script (#10902)
- Make TARGETS a list to preserve order. - Replace edge_ops with a dict mapping test op names to edge op names, e.g. op_name_map["split_tensor"] = "split_copy.Tensor" Both changes are slightly less efficient compared to using sets, but more convenient to use when handling the results. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 9dece67 commit 0121dae

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

backends/arm/scripts/parse_test_names.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,21 @@
1818
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
1919

2020
# Add all targets and TOSA profiles we support here.
21-
TARGETS = {"tosa_BI", "tosa_MI", "u55_BI", "u85_BI"}
21+
TARGETS = ["tosa_MI", "tosa_BI", "u55_BI", "u85_BI"]
2222

2323

24-
def get_edge_ops():
24+
def get_op_name_map():
2525
"""
26-
Returns a set with edge_ops with names on the form to be used in unittests:
26+
Returns a mapping from names on the form to be used in unittests to edge op:
2727
1. Names are in lowercase.
28-
2. Overload is ignored if it is 'default', otherwise its appended with an underscore.
28+
2. Overload is ignored if 'default', otherwise it's appended with an underscore.
2929
3. Overly verbose name are shortened by removing certain prefixes/suffixes.
3030
3131
Examples:
3232
abs.default -> abs
3333
split_copy.Tensor -> split_tensor
3434
"""
35-
edge_ops = set()
35+
map = {}
3636
for edge_name in ALL_EDGE_OPS:
3737
op, overload = edge_name.split(".")
3838

@@ -45,21 +45,24 @@ def get_edge_ops():
4545
overload = overload.lower()
4646

4747
if overload == "default":
48-
edge_ops.add(op)
48+
map[op] = edge_name
4949
else:
50-
edge_ops.add(f"{op}_{overload}")
50+
map[f"{op}_{overload}"] = edge_name
5151

52-
return edge_ops
52+
return map
5353

5454

55-
def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]:
55+
def parse_test_name(
56+
test_name: str, op_name_map: dict[str, str]
57+
) -> tuple[str, str, bool]:
5658
"""
5759
Parses a test name on the form
5860
test_OP_TARGET_<not_delegated>_<any_other_info>
59-
where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
60-
The "not_delegated" suffix indicates that the test tests that the op is not delegated.
61+
where OP must match a key in op_name_map and TARGET one string in TARGETS. The
62+
"not_delegated" suffix indicates that the test tests that the op is not delegated.
6163
62-
Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
64+
Examples of valid names: "test_mm_u55_BI_not_delegated" and
65+
"test_add_scalar_tosa_MI_two_inputs".
6366
6467
Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
6568
"""
@@ -83,7 +86,7 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
8386

8487
assert target != "None", f"{test_name} does not contain one of {TARGETS}"
8588
assert (
86-
op in edge_ops
89+
op in op_name_map.keys()
8790
), f"Parsed unvalid OP from {test_name}, {op} does not exist in edge.yaml or CUSTOM_EDGE_OPS"
8891

8992
return op, target, is_delegated
@@ -95,13 +98,13 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
9598

9699
sys.tracebacklimit = 0 # Do not print stack trace
97100

98-
edge_ops = get_edge_ops()
101+
op_name_map = get_op_name_map()
99102
exit_code = 0
100103

101104
for test_name in sys.argv[1:]:
102105
try:
103106
assert test_name[:5] == "test_", f"Unexpected input: {test_name}"
104-
parse_test_name(test_name, edge_ops)
107+
parse_test_name(test_name, op_name_map)
105108
except AssertionError as e:
106109
print(e)
107110
exit_code = 1

0 commit comments

Comments
 (0)