Skip to content

Commit 0acee7c

Browse files
committed
Generously capture all special cases from stubs
1 parent b8bb10f commit 0acee7c

File tree

1 file changed

+75
-80
lines changed

1 file changed

+75
-80
lines changed

array_api_tests/test_special_cases.py

Lines changed: 75 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ def __repr__(self) -> str:
526526
return f"{self.__class__.__name__}(<{self}>)"
527527

528528

529+
r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
530+
r_case = re.compile(r"\s+-\s*(.*)\.")
531+
532+
529533
class UnaryCond(Protocol):
530534
def __call__(self, i: float) -> bool:
531535
...
@@ -586,7 +590,7 @@ def check_result(i: float, result: float) -> bool:
586590
return check_result
587591

588592

589-
def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
593+
def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
590594
"""
591595
Parses a Sphinx-formatted docstring of a unary function to return a list of
592596
codified unary cases, e.g.
@@ -616,7 +620,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
616620
... an array containing the square root of each element in ``x``
617621
... '''
618622
...
619-
>>> unary_cases = parse_unary_docstring(sqrt.__doc__)
623+
>>> case_block = r_case_block.match(sqrt.__doc__).group(1)
624+
>>> unary_cases = parse_unary_case_block(case_block)
620625
>>> for case in unary_cases:
621626
... print(repr(case))
622627
UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +636,10 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
631636
True
632637
633638
"""
634-
635-
match = r_special_cases.search(docstring)
636-
if match is None:
637-
return []
638-
lines = match.group(1).split("\n")[:-1]
639639
cases = []
640-
for line in lines:
641-
if m := r_case.match(line):
642-
case = m.group(1)
643-
else:
644-
warn(f"line not machine-readable: '{line}'")
645-
continue
646-
if m := r_unary_case.search(case):
640+
for case_m in r_case.finditer(case_block):
641+
case_str = case_m.group(1)
642+
if m := r_unary_case.search(case_str):
647643
try:
648644
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
649645
_check_result, result_expr = parse_result(m.group(2))
@@ -662,11 +658,11 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
662658
check_result=check_result,
663659
)
664660
cases.append(case)
665-
elif m := r_even_round_halves_case.search(case):
661+
elif m := r_even_round_halves_case.search(case_str):
666662
cases.append(even_round_halves_case)
667663
else:
668-
if not r_remaining_case.search(case):
669-
warn(f"case not machine-readable: '{case}'")
664+
if not r_remaining_case.search(case_str):
665+
warn(f"case not machine-readable: '{case_str}'")
670666
return cases
671667

672668

@@ -690,12 +686,6 @@ class BinaryCase(Case):
690686
check_result: BinaryResultCheck
691687

692688

693-
r_special_cases = re.compile(
694-
r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
695-
r"For floating-point operands,\n+"
696-
r"((?:\s*-\s*.*\n)+)"
697-
)
698-
r_case = re.compile(r"\s+-\s*(.*)\.\n?")
699689
r_binary_case = re.compile("If (.+), the result (.+)")
700690
r_remaining_case = re.compile("In the remaining cases.+")
701691
r_cond_sep = re.compile(r"(?<!``x1_i``),? and |(?<!i\.e\.), ")
@@ -880,8 +870,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:
880870
881871
"""
882872
case_m = r_binary_case.match(case_str)
883-
if case_m is None:
884-
raise ParseError(case_str)
873+
assert case_m is not None # sanity check
885874
cond_strs = r_cond_sep.split(case_m.group(1))
886875

887876
partial_conds = []
@@ -1078,7 +1067,7 @@ def cond(i1: float, i2: float) -> bool:
10781067
r_redundant_case = re.compile("result.+determined by the rule already stated above")
10791068

10801069

1081-
def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1070+
def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
10821071
"""
10831072
Parses a Sphinx-formatted docstring of a binary function to return a list of
10841073
codified binary cases, e.g.
@@ -1108,29 +1097,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11081097
... an array containing the results
11091098
... '''
11101099
...
1111-
>>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
1100+
>>> case_block = r_case_block.match(logaddexp.__doc__).group(1)
1101+
>>> binary_cases = parse_binary_case_block(case_block)
11121102
>>> for case in binary_cases:
11131103
... print(repr(case))
11141104
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
11151105
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
11161106
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
11171107
11181108
"""
1119-
1120-
match = r_special_cases.search(docstring)
1121-
if match is None:
1122-
return []
1123-
lines = match.group(1).split("\n")[:-1]
11241109
cases = []
1125-
for line in lines:
1126-
if m := r_case.match(line):
1127-
case_str = m.group(1)
1128-
else:
1129-
warn(f"line not machine-readable: '{line}'")
1130-
continue
1110+
for case_m in r_case.finditer(case_block):
1111+
case_str = case_m.group(1)
11311112
if r_redundant_case.search(case_str):
11321113
continue
1133-
if m := r_binary_case.match(case_str):
1114+
if r_binary_case.match(case_str):
11341115
try:
11351116
case = parse_binary_case(case_str)
11361117
cases.append(case)
@@ -1142,14 +1123,19 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11421123
return cases
11431124

11441125

1126+
category_stub_pairs = [(c, s) for c, stubs in category_to_funcs.items() for s in stubs]
11451127
unary_params = []
11461128
binary_params = []
11471129
iop_params = []
11481130
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
1149-
for stub in category_to_funcs["elementwise"]:
1131+
for category, stub in category_stub_pairs:
11501132
if stub.__doc__ is None:
11511133
warn(f"{stub.__name__}() stub has no docstring")
11521134
continue
1135+
if m := r_case_block.search(stub.__doc__):
1136+
case_block = m.group(1)
1137+
else:
1138+
continue
11531139
marks = []
11541140
try:
11551141
func = getattr(xp, stub.__name__)
@@ -1163,47 +1149,56 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11631149
if len(sig.parameters) == 0:
11641150
warn(f"{func=} has no parameters")
11651151
continue
1166-
if param_names[0] == "x":
1167-
if cases := parse_unary_docstring(stub.__doc__):
1168-
name_to_func = {stub.__name__: func}
1169-
if stub.__name__ in func_to_op.keys():
1170-
op_name = func_to_op[stub.__name__]
1171-
op = getattr(operator, op_name)
1172-
name_to_func[op_name] = op
1173-
for func_name, func in name_to_func.items():
1174-
for case in cases:
1175-
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1176-
p = pytest.param(func_name, func, case, id=id_)
1177-
unary_params.append(p)
1178-
continue
1179-
if len(sig.parameters) == 1:
1180-
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
1181-
continue
1182-
if param_names[0] == "x1" and param_names[1] == "x2":
1183-
if cases := parse_binary_docstring(stub.__doc__):
1184-
name_to_func = {stub.__name__: func}
1185-
if stub.__name__ in func_to_op.keys():
1186-
op_name = func_to_op[stub.__name__]
1187-
op = getattr(operator, op_name)
1188-
name_to_func[op_name] = op
1189-
# We collect inplaceoperator test cases seperately
1190-
iop_name = "__i" + op_name[2:]
1191-
iop = getattr(operator, iop_name)
1192-
for case in cases:
1193-
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
1194-
p = pytest.param(iop_name, iop, case, id=id_)
1195-
iop_params.append(p)
1196-
for func_name, func in name_to_func.items():
1197-
for case in cases:
1198-
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1199-
p = pytest.param(func_name, func, case, id=id_)
1200-
binary_params.append(p)
1201-
continue
1152+
if category == "elementwise":
1153+
if param_names[0] == "x":
1154+
if cases := parse_unary_case_block(case_block):
1155+
name_to_func = {stub.__name__: func}
1156+
if stub.__name__ in func_to_op.keys():
1157+
op_name = func_to_op[stub.__name__]
1158+
op = getattr(operator, op_name)
1159+
name_to_func[op_name] = op
1160+
for func_name, func in name_to_func.items():
1161+
for case in cases:
1162+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1163+
p = pytest.param(func_name, func, case, id=id_)
1164+
unary_params.append(p)
1165+
else:
1166+
warn("TODO")
1167+
continue
1168+
if len(sig.parameters) == 1:
1169+
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
1170+
continue
1171+
if param_names[0] == "x1" and param_names[1] == "x2":
1172+
if cases := parse_binary_case_block(case_block):
1173+
name_to_func = {stub.__name__: func}
1174+
if stub.__name__ in func_to_op.keys():
1175+
op_name = func_to_op[stub.__name__]
1176+
op = getattr(operator, op_name)
1177+
name_to_func[op_name] = op
1178+
# We collect inplace operator test cases seperately
1179+
iop_name = "__i" + op_name[2:]
1180+
iop = getattr(operator, iop_name)
1181+
for case in cases:
1182+
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
1183+
p = pytest.param(iop_name, iop, case, id=id_)
1184+
iop_params.append(p)
1185+
for func_name, func in name_to_func.items():
1186+
for case in cases:
1187+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1188+
p = pytest.param(func_name, func, case, id=id_)
1189+
binary_params.append(p)
1190+
else:
1191+
warn("TODO")
1192+
continue
1193+
else:
1194+
warn(
1195+
f"{func=} starts with two parameters '{param_names[0]}' and "
1196+
f"'{param_names[1]}', which are not named 'x1' and 'x2'"
1197+
)
1198+
elif category == "statistical":
1199+
pass # TODO
12021200
else:
1203-
warn(
1204-
f"{func=} starts with two parameters '{param_names[0]}' and "
1205-
f"'{param_names[1]}', which are not named 'x1' and 'x2'"
1206-
)
1201+
warn("TODO")
12071202

12081203

12091204
# test_unary and test_binary naively generate arrays, i.e. arrays that might not

0 commit comments

Comments
 (0)