Skip to content

Commit f3f683f

Browse files
committed
Parse sign cases
1 parent 3c79c7f commit f3f683f

File tree

1 file changed

+65
-36
lines changed

1 file changed

+65
-36
lines changed

array_api_tests/test_special_cases.py

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,11 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
313313

314314
@dataclass
315315
class ValueCondFactory(CondFactory):
316-
input_: Union[Literal["i1"], Literal["i2"], Literal["either"]]
317-
groups_i: int
316+
input_: Union[Literal["i1"], Literal["i2"], Literal["either"], Literal["both"]]
317+
re_groups_i: int
318318

319319
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
320-
group = groups[self.groups_i]
320+
group = groups[self.re_groups_i]
321321

322322
if m := r_array_element.match(group):
323323
cond_factory = make_eq if m.group(1) != "-" else make_neq
@@ -390,11 +390,16 @@ def cond(i1: float, i2: float) -> bool:
390390
def cond(i1: float, i2: float) -> bool:
391391
return final_cond(i2)
392392

393-
else:
393+
elif self.input_ == "either":
394394

395395
def cond(i1: float, i2: float) -> bool:
396396
return final_cond(i1) or final_cond(i2)
397397

398+
else:
399+
400+
def cond(i1: float, i2: float) -> bool:
401+
return final_cond(i1) and final_cond(i2)
402+
398403
return cond
399404

400405

@@ -417,14 +422,28 @@ def __repr__(self) -> str:
417422
return f"{self.__class__.__name__}({f_cond_factories})"
418423

419424

425+
@dataclass
426+
class SignCondFactory(CondFactory):
427+
re_groups_i: int
428+
429+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
430+
group = groups[self.re_groups_i]
431+
if group == "the same mathematical sign":
432+
return same_sign
433+
elif group == "different mathematical signs":
434+
return diff_sign
435+
else:
436+
raise ValueParseError(group)
437+
438+
420439
BinaryResultCheck = Callable[[float, float, float], bool]
421440

422441

423442
class ResultCheckFactory(NamedTuple):
424-
groups_i: int
443+
re_groups_i: int
425444

426445
def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
427-
group = groups[self.groups_i]
446+
group = groups[self.re_groups_i]
428447

429448
if m := r_array_element.match(group):
430449
cond_factory = make_eq if m.group(1) != "-" else make_neq
@@ -458,6 +477,29 @@ def cond(i1: float, i2: float, result: float) -> bool:
458477
return cond
459478

460479

480+
class ResultSignCheckFactory(ResultCheckFactory):
481+
def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
482+
group = groups[self.re_groups_i]
483+
if group == "positive":
484+
485+
def cond(i1: float, i2: float, result: float) -> bool:
486+
if math.isnan(result):
487+
return True
488+
return result > 0 or ph.is_pos_zero(result)
489+
490+
elif group == "negative":
491+
492+
def cond(i1: float, i2: float, result: float) -> bool:
493+
if math.isnan(result):
494+
return True
495+
return result < 0 or ph.is_neg_zero(result)
496+
497+
else:
498+
raise ValueParseError(group)
499+
500+
return cond
501+
502+
461503
class BinaryCase(NamedTuple):
462504
cond: BinaryCheck
463505
check_result: BinaryResultCheck
@@ -473,6 +515,8 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
473515
return BinaryCase(cond, check_result)
474516

475517

518+
r_result_sign = re.compile("([a-z]+) mathematical sign")
519+
476520
binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = {
477521
re.compile(
478522
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
@@ -499,32 +543,23 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
499543
# ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq),
500544
re.compile(
501545
"If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
546+
): BinaryCaseFactory(ValueCondFactory("either", 0), ResultCheckFactory(1)),
547+
re.compile(
548+
"If ``x1_i`` and ``x2_i`` have (.+signs?), "
549+
f"the result has a {r_result_sign.pattern}"
550+
): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)),
551+
re.compile(
552+
"If ``x1_i`` and ``x2_i`` have (.+signs?) and are both (.+), "
553+
f"the result has a {r_result_sign.pattern}"
502554
): BinaryCaseFactory(
503-
ValueCondFactory("either", 0),
504-
ResultCheckFactory(1),
555+
AndCondFactory(SignCondFactory(0), ValueCondFactory("both", 1)),
556+
ResultSignCheckFactory(2),
505557
),
506-
# re.compile(
507-
# "If ``x1_i`` and ``x2_i`` have the same mathematical sign, "
508-
# "the result has a (.+)"
509-
# ): lambda: same_sign,
510-
# re.compile(
511-
# "If ``x1_i`` and ``x2_i`` have different mathematical signs, "
512-
# "the result has a (.+)"
513-
# ): lambda: diff_sign,
514-
# re.compile(
515-
# "If ``x1_i`` and ``x2_i`` have the same mathematical sign and "
516-
# "are both (.+), the result has a (.+)"
517-
# ): lambda v: (
518-
# lambda i1, i2: same_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2)
519-
# ),
520-
# re.compile(
521-
# "If ``x1_i`` and ``x2_i`` have different mathematical signs and "
522-
# "are both (.+), the result has a (.+)"
523-
# ): lambda v: (
524-
# lambda i1, i2: diff_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2)
525-
# ),
526-
# re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined")
527-
# re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined")
558+
re.compile(
559+
"If ``x1_i`` and ``x2_i`` have (.+signs?), the result has a "
560+
rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result "
561+
r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\."
562+
): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)),
528563
}
529564

530565

@@ -682,9 +717,3 @@ def test_binary(func_name, func, cases, x1, x2):
682717
# )
683718
break
684719
assume(good_example)
685-
686-
687-
# TODO: remove
688-
print(
689-
f"no. of cases={sum(len(cases) for _, _, cases in binary_params)}"
690-
)

0 commit comments

Comments
 (0)