18
18
ALL_EDGE_OPS = SAMPLE_INPUT .keys () | CUSTOM_EDGE_OPS
19
19
20
20
# Add all targets and TOSA profiles we support here.
21
- TARGETS = { "tosa_BI " , "tosa_MI " , "u55_BI" , "u85_BI" }
21
+ TARGETS = [ "tosa_MI " , "tosa_BI " , "u55_BI" , "u85_BI" ]
22
22
23
23
24
- def get_edge_ops ():
24
+ def get_op_name_map ():
25
25
"""
26
- Returns a set with edge_ops with names on the form to be used in unittests:
26
+ Returns a mapping from names on the form to be used in unittests to edge op :
27
27
1. Names are in lowercase.
28
- 2. Overload is ignored if it is 'default', otherwise its appended with an underscore.
28
+ 2. Overload is ignored if 'default', otherwise it's appended with an underscore.
29
29
3. Overly verbose name are shortened by removing certain prefixes/suffixes.
30
30
31
31
Examples:
32
32
abs.default -> abs
33
33
split_copy.Tensor -> split_tensor
34
34
"""
35
- edge_ops = set ()
35
+ map = {}
36
36
for edge_name in ALL_EDGE_OPS :
37
37
op , overload = edge_name .split ("." )
38
38
@@ -45,21 +45,24 @@ def get_edge_ops():
45
45
overload = overload .lower ()
46
46
47
47
if overload == "default" :
48
- edge_ops . add ( op )
48
+ map [ op ] = edge_name
49
49
else :
50
- edge_ops . add ( f"{ op } _{ overload } " )
50
+ map [ f"{ op } _{ overload } " ] = edge_name
51
51
52
- return edge_ops
52
+ return map
53
53
54
54
55
- def parse_test_name (test_name : str , edge_ops : set [str ]) -> tuple [str , str , bool ]:
55
+ def parse_test_name (
56
+ test_name : str , op_name_map : dict [str , str ]
57
+ ) -> tuple [str , str , bool ]:
56
58
"""
57
59
Parses a test name on the form
58
60
test_OP_TARGET_<not_delegated>_<any_other_info>
59
- where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
60
- The "not_delegated" suffix indicates that the test tests that the op is not delegated.
61
+ where OP must match a key in op_name_map and TARGET one string in TARGETS. The
62
+ "not_delegated" suffix indicates that the test tests that the op is not delegated.
61
63
62
- Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".
64
+ Examples of valid names: "test_mm_u55_BI_not_delegated" and
65
+ "test_add_scalar_tosa_MI_two_inputs".
63
66
64
67
Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
65
68
"""
@@ -83,7 +86,7 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
83
86
84
87
assert target != "None" , f"{ test_name } does not contain one of { TARGETS } "
85
88
assert (
86
- op in edge_ops
89
+ op in op_name_map . keys ()
87
90
), f"Parsed unvalid OP from { test_name } , { op } does not exist in edge.yaml or CUSTOM_EDGE_OPS"
88
91
89
92
return op , target , is_delegated
@@ -95,13 +98,13 @@ def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]
95
98
96
99
sys .tracebacklimit = 0 # Do not print stack trace
97
100
98
- edge_ops = get_edge_ops ()
101
+ op_name_map = get_op_name_map ()
99
102
exit_code = 0
100
103
101
104
for test_name in sys .argv [1 :]:
102
105
try :
103
106
assert test_name [:5 ] == "test_" , f"Unexpected input: { test_name } "
104
- parse_test_name (test_name , edge_ops )
107
+ parse_test_name (test_name , op_name_map )
105
108
except AssertionError as e :
106
109
print (e )
107
110
exit_code = 1
0 commit comments