Skip to content

Commit 0fb1a94

Browse files
committed
Cover most sign special cases
1 parent 17d06f9 commit 0fb1a94

File tree

1 file changed

+63
-25
lines changed

1 file changed

+63
-25
lines changed

array_api_tests/test_special_cases.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
# ------------------------------------------------------------------------------
2121

2222

23-
def make_eq(v: float) -> Callable[[float], bool]:
23+
UnaryCheck = Callable[[float], bool]
24+
BinaryCheck = Callable[[float, float], bool]
25+
26+
27+
def make_eq(v: float) -> UnaryCheck:
2428
if math.isnan(v):
2529
return math.isnan
2630
if v == 0:
@@ -35,7 +39,7 @@ def eq(i: float) -> bool:
3539
return eq
3640

3741

38-
def make_neq(v: float) -> Callable[[float], bool]:
42+
def make_neq(v: float) -> UnaryCheck:
3943
eq = make_eq(v)
4044

4145
def neq(i: float) -> bool:
@@ -44,7 +48,7 @@ def neq(i: float) -> bool:
4448
return neq
4549

4650

47-
def make_rough_eq(v: float) -> Callable[[float], bool]:
51+
def make_rough_eq(v: float) -> UnaryCheck:
4852
assert math.isfinite(v) # sanity check
4953

5054
def rough_eq(i: float) -> bool:
@@ -53,40 +57,42 @@ def rough_eq(i: float) -> bool:
5357
return rough_eq
5458

5559

56-
def make_gt(v: float):
60+
def make_gt(v: float) -> UnaryCheck:
5761
assert not math.isnan(v) # sanity check
5862

59-
def gt(i: float):
63+
def gt(i: float) -> bool:
6064
return i > v
6165

6266
return gt
6367

6468

65-
def make_lt(v: float):
69+
def make_lt(v: float) -> UnaryCheck:
6670
assert not math.isnan(v) # sanity check
6771

68-
def lt(i: float):
72+
def lt(i: float) -> bool:
6973
return i < v
7074

7175
return lt
7276

7377

74-
def make_or(cond1: Callable, cond2: Callable):
75-
def or_(i: float):
78+
def make_or(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
79+
def or_(i: float) -> bool:
7680
return cond1(i) or cond2(i)
7781

7882
return or_
7983

8084

81-
def make_and(cond1: Callable, cond2: Callable) -> Callable:
85+
def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
8286
def and_(i: float) -> bool:
8387
return cond1(i) or cond2(i)
8488

8589
return and_
8690

8791

88-
def make_bin_and_factory(make_cond1: Callable, make_cond2: Callable) -> Callable:
89-
def make_bin_and(v1: float, v2: float) -> Callable:
92+
def make_bin_and_factory(
93+
make_cond1: Callable[[float], UnaryCheck], make_cond2: Callable[[float], UnaryCheck]
94+
) -> Callable[[float, float], BinaryCheck]:
95+
def make_bin_and(v1: float, v2: float) -> BinaryCheck:
9096
cond1 = make_cond1(v1)
9197
cond2 = make_cond2(v2)
9298

@@ -98,8 +104,10 @@ def bin_and(i1: float, i2: float) -> bool:
98104
return make_bin_and
99105

100106

101-
def make_bin_or_factory(make_cond: Callable) -> Callable:
102-
def make_bin_or(v: float) -> Callable:
107+
def make_bin_or_factory(
108+
make_cond: Callable[[float], UnaryCheck]
109+
) -> Callable[[float], BinaryCheck]:
110+
def make_bin_or(v: float) -> BinaryCheck:
103111
cond = make_cond(v)
104112

105113
def bin_or(i1: float, i2: float) -> bool:
@@ -110,8 +118,10 @@ def bin_or(i1: float, i2: float) -> bool:
110118
return make_bin_or
111119

112120

113-
def absify_cond_factory(make_cond):
114-
def make_abs_cond(v: float):
121+
def absify_cond_factory(
122+
make_cond: Callable[[float], UnaryCheck]
123+
) -> Callable[[float], UnaryCheck]:
124+
def make_abs_cond(v: float) -> UnaryCheck:
115125
cond = make_cond(v)
116126

117127
def abs_cond(i: float) -> bool:
@@ -124,9 +134,10 @@ def abs_cond(i: float) -> bool:
124134

125135

126136
def make_bin_multi_and_factory(
127-
make_conds1: List[Callable], make_conds2: List[Callable]
137+
make_conds1: List[Callable[[float], UnaryCheck]],
138+
make_conds2: List[Callable[[float], UnaryCheck]],
128139
) -> Callable:
129-
def make_bin_multi_and(*values: float) -> Callable:
140+
def make_bin_multi_and(*values: float) -> BinaryCheck:
130141
assert len(values) == len(make_conds1) + len(make_conds2)
131142
conds1 = [make_cond(v) for make_cond, v in zip(make_conds1, values)]
132143
conds2 = [make_cond(v) for make_cond, v in zip(make_conds2, values[::-1])]
@@ -139,6 +150,14 @@ def bin_multi_and(i1: float, i2: float) -> bool:
139150
return make_bin_multi_and
140151

141152

153+
def same_sign(i1: float, i2: float) -> bool:
154+
return math.copysign(1, i1) == math.copysign(1, i2)
155+
156+
157+
def diff_sign(i1: float, i2: float) -> bool:
158+
return not same_sign(i1, i2)
159+
160+
142161
# Parse utils
143162
# ------------------------------------------------------------------------------
144163

@@ -271,6 +290,9 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
271290

272291

273292
binary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
293+
re.compile(
294+
"If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
295+
): lambda v: lambda _, i2: make_eq(v)(i2),
274296
re.compile(
275297
"If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)"
276298
): make_bin_and_factory(make_eq, lambda v: lambda i: i != v),
@@ -355,13 +377,29 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
355377
and make_or(make_eq(v3), make_eq(v4))(i2)
356378
)
357379
),
358-
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+)")
359-
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
360-
# re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign and are both (.+), the result has a (.+)")
361-
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+)")
362-
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined")
363-
# re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs and are both (.+), the result has a (.+)")
364-
# re.compile("If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+")
380+
re.compile(
381+
"If ``x1_i`` and ``x2_i`` have the same mathematical sign, "
382+
"the result has a (.+)"
383+
): lambda: same_sign,
384+
re.compile(
385+
"If ``x1_i`` and ``x2_i`` have different mathematical signs, "
386+
"the result has a (.+)"
387+
): lambda: diff_sign,
388+
re.compile(
389+
"If ``x1_i`` and ``x2_i`` have the same mathematical sign and "
390+
"are both (.+), the result has a (.+)"
391+
): lambda v: lambda i1, i2: same_sign(i1, i2)
392+
and make_eq(v)(i1)
393+
and make_eq(v)(i2),
394+
re.compile(
395+
"If ``x1_i`` and ``x2_i`` have different mathematical signs and "
396+
"are both (.+), the result has a (.+)"
397+
): lambda v: lambda i1, i2: diff_sign(i1, i2)
398+
and make_eq(v)(i1)
399+
and make_eq(v)(i2),
400+
# TODO: support capturing values that come after the result
401+
# re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined")
402+
# re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined")
365403
}
366404

367405

0 commit comments

Comments
 (0)