Skip to content

Commit 3c79c7f

Browse files
committed
Parse either cases
1 parent 645451f commit 3c79c7f

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

array_api_tests/test_special_cases.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,14 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
306306

307307
r_not = re.compile("not (?:equal to )?(.+)")
308308
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
309+
r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}")
309310
r_gt = re.compile(f"greater than {r_code.pattern}")
310311
r_lt = re.compile(f"less than {r_code.pattern}")
311-
r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}")
312312

313313

314-
class ValueCondFactory(NamedTuple): # TODO: inherit from CondFactory as well
315-
input_: Union[Literal["i1"], Literal["i2"]]
314+
@dataclass
315+
class ValueCondFactory(CondFactory):
316+
input_: Union[Literal["i1"], Literal["i2"], Literal["either"]]
316317
groups_i: int
317318

318319
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
@@ -327,6 +328,7 @@ def cond(i1: float, i2: float) -> bool:
327328
return _cond(i1)
328329

329330
else:
331+
assert self.input_ == "i2" # sanity check
330332

331333
def cond(i1: float, i2: float) -> bool:
332334
_cond = cond_factory(i1)
@@ -383,11 +385,16 @@ def cond(i1: float, i2: float) -> bool:
383385
def cond(i1: float, i2: float) -> bool:
384386
return final_cond(i1)
385387

386-
else:
388+
elif self.input_ == "i2":
387389

388390
def cond(i1: float, i2: float) -> bool:
389391
return final_cond(i2)
390392

393+
else:
394+
395+
def cond(i1: float, i2: float) -> bool:
396+
return final_cond(i1) or final_cond(i2)
397+
391398
return cond
392399

393400

@@ -490,12 +497,12 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
490497
# re.compile(
491498
# r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)"
492499
# ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq),
493-
# re.compile(
494-
# "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
495-
# ): make_bin_and_factory(make_eq, make_eq),
496-
# re.compile(
497-
# "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
498-
# ): make_bin_or_factory(make_eq),
500+
re.compile(
501+
"If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
502+
): BinaryCaseFactory(
503+
ValueCondFactory("either", 0),
504+
ResultCheckFactory(1),
505+
),
499506
# re.compile(
500507
# "If ``x1_i`` and ``x2_i`` have the same mathematical sign, "
501508
# "the result has a (.+)"
@@ -516,7 +523,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
516523
# ): lambda v: (
517524
# lambda i1, i2: diff_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2)
518525
# ),
519-
# TODO: support capturing values that come after the result
520526
# re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined")
521527
# re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined")
522528
}
@@ -676,3 +682,9 @@ def test_binary(func_name, func, cases, x1, x2):
676682
# )
677683
break
678684
assume(good_example)
685+
686+
687+
# TODO: remove
688+
print(
689+
f"no. of cases={sum(len(cases) for _, _, cases in binary_params)}"
690+
)

0 commit comments

Comments
 (0)