Skip to content

Commit 17d06f9

Browse files
committed
More binary cases coverage
1 parent 3d05447 commit 17d06f9

File tree

2 files changed

+180
-9
lines changed

2 files changed

+180
-9
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import math
2+
3+
from ..test_special_cases import parse_result
4+
5+
6+
def test_parse_result():
7+
s_result = "an implementation-dependent approximation to ``+3π/4``"
8+
assert parse_result(s_result).value == 3 * math.pi / 4

array_api_tests/test_special_cases.py

Lines changed: 172 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import math
33
import re
4-
from typing import Callable, Dict, NamedTuple, Pattern
4+
from typing import Callable, Dict, List, NamedTuple, Pattern
55
from warnings import warn
66

77
import pytest
@@ -16,6 +16,9 @@
1616
from ._array_module import mod as xp
1717
from .stubs import category_to_funcs
1818

19+
# Condition factories
20+
# ------------------------------------------------------------------------------
21+
1922

2023
def make_eq(v: float) -> Callable[[float], bool]:
2124
if math.isnan(v):
@@ -32,6 +35,15 @@ def eq(i: float) -> bool:
3235
return eq
3336

3437

38+
def make_neq(v: float) -> Callable[[float], bool]:
39+
eq = make_eq(v)
40+
41+
def neq(i: float) -> bool:
42+
return not eq(i)
43+
44+
return neq
45+
46+
3547
def make_rough_eq(v: float) -> Callable[[float], bool]:
3648
assert math.isfinite(v) # sanity check
3749

@@ -66,6 +78,71 @@ def or_(i: float):
6678
return or_
6779

6880

81+
def make_and(cond1: Callable, cond2: Callable) -> Callable:
82+
def and_(i: float) -> bool:
83+
return cond1(i) or cond2(i)
84+
85+
return and_
86+
87+
88+
def make_bin_and_factory(make_cond1: Callable, make_cond2: Callable) -> Callable:
89+
def make_bin_and(v1: float, v2: float) -> Callable:
90+
cond1 = make_cond1(v1)
91+
cond2 = make_cond2(v2)
92+
93+
def bin_and(i1: float, i2: float) -> bool:
94+
return cond1(i1) and cond2(i2)
95+
96+
return bin_and
97+
98+
return make_bin_and
99+
100+
101+
def make_bin_or_factory(make_cond: Callable) -> Callable:
102+
def make_bin_or(v: float) -> Callable:
103+
cond = make_cond(v)
104+
105+
def bin_or(i1: float, i2: float) -> bool:
106+
return cond(i1) or cond(i2)
107+
108+
return bin_or
109+
110+
return make_bin_or
111+
112+
113+
def absify_cond_factory(make_cond):
114+
def make_abs_cond(v: float):
115+
cond = make_cond(v)
116+
117+
def abs_cond(i: float) -> bool:
118+
i = abs(i)
119+
return cond(i)
120+
121+
return abs_cond
122+
123+
return make_abs_cond
124+
125+
126+
def make_bin_multi_and_factory(
127+
make_conds1: List[Callable], make_conds2: List[Callable]
128+
) -> Callable:
129+
def make_bin_multi_and(*values: float) -> Callable:
130+
assert len(values) == len(make_conds1) + len(make_conds2)
131+
conds1 = [make_cond(v) for make_cond, v in zip(make_conds1, values)]
132+
conds2 = [make_cond(v) for make_cond, v in zip(make_conds2, values[::-1])]
133+
134+
def bin_multi_and(i1: float, i2: float) -> bool:
135+
return all(cond(i1) for cond in conds1) and all(cond(i2) for cond in conds2)
136+
137+
return bin_multi_and
138+
139+
return make_bin_multi_and
140+
141+
142+
# Parse utils
143+
# ------------------------------------------------------------------------------
144+
145+
69146
repr_to_value = {
70147
"NaN": float("nan"),
71148
"infinity": float("infinity"),
@@ -183,6 +260,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
183260
result = parse_result(s_result)
184261
except ValueParseError as e:
185262
warn(f"result not machine-readable: '{e.value}'")
263+
186264
break
187265
condition_to_result[cond] = result
188266
break
@@ -193,10 +271,97 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
193271

194272

195273
binary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
274+
re.compile(
275+
"If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)"
276+
): make_bin_and_factory(make_eq, lambda v: lambda i: i != v),
277+
re.compile(
278+
"If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), "
279+
"and ``x2_i`` is (.+), the result is (.+)"
280+
): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]),
281+
re.compile(
282+
"If ``x1_i`` is less than (.+), ``x1_i`` is (.+), "
283+
"and ``x2_i`` is (.+), the result is (.+)"
284+
): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]),
285+
re.compile(
286+
"If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), "
287+
"and ``x2_i`` is not (.+), the result is (.+)"
288+
): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]),
289+
re.compile(
290+
"If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
291+
"and ``x2_i`` is (.+), the result is (.+)"
292+
): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]),
293+
re.compile(
294+
"If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
295+
"and ``x2_i`` is not (.+), the result is (.+)"
296+
): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]),
297+
re.compile(
298+
"If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
299+
"and ``x2_i`` is (.+), the result is (.+)"
300+
): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]),
301+
re.compile(
302+
"If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
303+
"and ``x2_i`` is not (.+), the result is (.+)"
304+
): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]),
305+
re.compile(
306+
"If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)"
307+
): make_bin_and_factory(make_gt, make_eq),
308+
re.compile(
309+
"If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)"
310+
): make_bin_and_factory(make_eq, make_gt),
311+
re.compile(
312+
"If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
313+
): make_bin_and_factory(make_lt, make_eq),
314+
re.compile(
315+
"If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)"
316+
): make_bin_and_factory(make_eq, make_lt),
317+
re.compile(
318+
"If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)"
319+
): make_bin_and_factory(make_neq, make_eq),
320+
re.compile(
321+
"If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)"
322+
): make_bin_and_factory(make_eq, make_neq),
323+
re.compile(
324+
r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), "
325+
"the result is (.+)"
326+
): make_bin_and_factory(absify_cond_factory(make_gt), make_eq),
327+
re.compile(
328+
r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
329+
): make_bin_and_factory(absify_cond_factory(make_lt), make_eq),
330+
re.compile(
331+
r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)"
332+
): make_bin_and_factory(absify_cond_factory(make_eq), make_eq),
196333
re.compile(
197334
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
198-
): lambda v1, v2: lambda i1, i2: make_eq(v1)(i1)
199-
and make_eq(v2)(i2),
335+
): make_bin_and_factory(make_eq, make_eq),
336+
re.compile(
337+
"If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
338+
): make_bin_or_factory(make_eq),
339+
re.compile(
340+
"If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)"
341+
): lambda v1, v2, v3: (
342+
lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2)
343+
),
344+
re.compile(
345+
"If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)"
346+
): lambda v1, v2, v3: (
347+
lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2)
348+
),
349+
re.compile(
350+
"If ``x1_i`` is either (.+) or (.+) and "
351+
"``x2_i`` is either (.+) or (.+), the result is (.+)"
352+
): lambda v1, v2, v3, v4: (
353+
lambda i1, i2: (
354+
make_or(make_eq(v1), make_eq(v2))(i1)
355+
and make_or(make_eq(v3), make_eq(v4))(i2)
356+
)
357+
),
358+
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+)")
359+
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
360+
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign and are both (.+), the result has a (.+)")
361+
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+)")
362+
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
363+
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs and are both (.+), the result has a (.+)")
364+
# re.compile("If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+")
200365
}
201366

202367

@@ -221,12 +386,6 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]:
221386
warn(f"value not machine-readable: '{e.value}'")
222387
break
223388
cond = make_cond(*values)
224-
if (
225-
"atan2" in docstring
226-
and ph.is_pos_zero(values[0])
227-
and ph.is_neg_zero(values[1])
228-
):
229-
breakpoint()
230389
try:
231390
result = parse_result(s_result)
232391
except ValueParseError as e:
@@ -240,6 +399,10 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]:
240399
return condition_to_result
241400

242401

402+
# Here be the tests
403+
# ------------------------------------------------------------------------------
404+
405+
243406
unary_params = []
244407
binary_params = []
245408
for stub in category_to_funcs["elementwise"]:

0 commit comments

Comments
 (0)