Skip to content

Commit 86a9756

Browse files
committed
Factories for result check functions
1 parent 0334387 commit 86a9756

File tree

1 file changed

+59
-23
lines changed

1 file changed

+59
-23
lines changed

array_api_tests/test_special_cases.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
305305

306306

307307
r_not_code = re.compile(f"not (?:equal to )?{r_code.pattern}")
308-
r_array_element = re.compile(r"``([+-]?)x[12]_i``")
308+
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
309309
r_gt = re.compile(f"greater than {r_code.pattern}")
310310
r_lt = re.compile(f"less than {r_code.pattern}")
311311
r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}")
@@ -333,8 +333,8 @@ def cond(i1: float, i2: float) -> bool:
333333
return _cond(i2)
334334

335335
return cond
336-
# this branch must come after checking for array elements
337-
elif m := r_code.match(group):
336+
337+
if m := r_code.match(group):
338338
value = parse_value(m.group(1))
339339
_cond = make_eq(value)
340340
elif m := r_not_code.match(group):
@@ -398,39 +398,75 @@ def cond(i1: float, i2: float) -> bool:
398398

399399
return cond
400400

401+
def __repr__(self) -> str:
402+
f_cond_factories = ", ".join(
403+
repr(cond_factory) for cond_factory in self.cond_factories
404+
)
405+
return f"{self.__class__.__name__}({f_cond_factories})"
401406

402-
class BinaryCase(NamedTuple):
403-
cond: BinaryCheck
404-
check_result: Callable[[float], bool]
405407

408+
BinaryResultCheck = Callable[[float, float, float], bool]
406409

407-
class BinaryCaseFactory(NamedTuple):
408-
cond_factory: CondFactory
409-
result_re_group: int
410410

411-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
412-
in_cond = self.cond_factory(groups)
411+
class ResultCheckFactory(NamedTuple):
412+
re_group: int
413+
414+
def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
415+
group = groups[self.re_group]
416+
417+
if m := r_array_element.match(group):
418+
cond_factory = make_eq if m.group(1) != "-" else make_neq
419+
420+
if m.group(2) == "1":
421+
422+
def cond(i1: float, i2: float, result: float) -> bool:
423+
_cond = cond_factory(i1)
424+
return _cond(result)
425+
426+
else:
413427

414-
s_result = groups[self.result_re_group]
415-
if m := r_array_element.match(s_result):
416-
raise ValueParseError(s_result) # TODO
417-
elif m := r_code.match(s_result):
428+
def cond(i1: float, i2: float, result: float) -> bool:
429+
_cond = cond_factory(i2)
430+
return _cond(result)
431+
432+
return cond
433+
434+
if m := r_code.match(group):
418435
value = parse_value(m.group(1))
419-
out_cond = make_eq(value)
420-
elif m := r_approx_value.match(s_result):
436+
_cond = make_eq(value)
437+
elif m := r_approx_value.match(group):
421438
value = parse_value(m.group(1))
422-
out_cond = make_rough_eq(value)
439+
_cond = make_rough_eq(value)
423440
else:
424-
raise ValueParseError(s_result)
441+
raise ValueParseError(group)
442+
443+
def cond(i1: float, i2: float, result: float) -> bool:
444+
return _cond(result)
445+
446+
return cond
425447

426-
return BinaryCase(in_cond, out_cond)
448+
449+
class BinaryCase(NamedTuple):
450+
cond: BinaryCheck
451+
check_result: BinaryResultCheck
452+
453+
454+
class BinaryCaseFactory(NamedTuple):
455+
cond_factory: CondFactory
456+
check_result_factory: ResultCheckFactory
457+
458+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
459+
cond = self.cond_factory(groups)
460+
check_result = self.check_result_factory(groups)
461+
return BinaryCase(cond, check_result)
427462

428463

429464
binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = {
430465
re.compile(
431466
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
432467
): BinaryCaseFactory(
433-
AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), 2
468+
AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)),
469+
ResultCheckFactory(2),
434470
),
435471
# re.compile(
436472
# "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
@@ -671,8 +707,8 @@ def test_binary(func_name, func, cases, x1, x2):
671707
for case in cases:
672708
if case.cond(l, r):
673709
good_example = True
674-
out = float(res[o_idx])
675-
assert case.check_result(out)
710+
o = float(res[o_idx])
711+
assert case.check_result(l, r, o)
676712
# f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
677713
# f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
678714
# f_out = f"{sh.fmt_idx('out', o_idx)}={out}"

0 commit comments

Comments
 (0)