Skip to content

Commit 771ea51

Browse files
committed
Cond reprs, merge abs cond logic, fix approx cond factories
1 parent b5af57d commit 771ea51

File tree

1 file changed

+113
-76
lines changed

1 file changed

+113
-76
lines changed

array_api_tests/test_special_cases.py

Lines changed: 113 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,19 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
299299
return cases
300300

301301

302-
class CondFactory(Protocol):
303-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
302+
class BinaryCond(NamedTuple):
303+
cond: BinaryCheck
304+
repr_: str
305+
306+
def __call__(self, i1: float, i2: float) -> bool:
307+
return self.cond(i1, i2)
308+
309+
def __repr__(self):
310+
return self.repr_
311+
312+
313+
class BinaryCondFactory(Protocol):
314+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
304315
...
305316

306317

@@ -310,31 +321,43 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
310321
r_gt = re.compile(f"greater than {r_code.pattern}")
311322
r_lt = re.compile(f"less than {r_code.pattern}")
312323

324+
x1_i = "x1ᵢ"
325+
x2_i = "x2ᵢ"
326+
313327

314328
@dataclass
315-
class ValueCondFactory(CondFactory):
329+
class ValueCondFactory(BinaryCondFactory):
316330
input_: Union[Literal["i1"], Literal["i2"], Literal["either"], Literal["both"]]
317331
re_groups_i: int
332+
abs_: bool = False
318333

319-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
334+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
320335
group = groups[self.re_groups_i]
321336

322337
if m := r_array_element.match(group):
323-
cond_factory = make_eq if m.group(1) != "-" else make_neq
338+
assert not self.abs_ # sanity check
339+
sign = m.group(1)
340+
if sign == "-":
341+
signer = lambda i: -i
342+
else:
343+
signer = lambda i: i
344+
324345
if self.input_ == "i1":
346+
repr_ = f"{x1_i} == {sign}{x2_i}"
325347

326348
def cond(i1: float, i2: float) -> bool:
327-
_cond = cond_factory(i2)
349+
_cond = make_eq(signer(i2))
328350
return _cond(i1)
329351

330352
else:
331353
assert self.input_ == "i2" # sanity check
354+
repr_ = f"{x2_i} == {sign}{x1_i}"
332355

333356
def cond(i1: float, i2: float) -> bool:
334-
_cond = cond_factory(i1)
357+
_cond = make_eq(signer(i1))
335358
return _cond(i2)
336359

337-
return cond
360+
return BinaryCond(cond, repr_)
338361

339362
if m := r_not.match(group):
340363
group = m.group(1)
@@ -345,100 +368,105 @@ def cond(i1: float, i2: float) -> bool:
345368
if m := r_code.match(group):
346369
value = parse_value(m.group(1))
347370
_cond = make_eq(value)
371+
repr_template = "{} == " + str(value)
348372
elif m := r_gt.match(group):
349373
value = parse_value(m.group(1))
350374
_cond = make_gt(value)
375+
repr_template = "{} > " + str(value)
351376
elif m := r_lt.match(group):
352377
value = parse_value(m.group(1))
353378
_cond = make_lt(value)
379+
repr_template = "{} < " + str(value)
354380
elif m := r_either_code.match(group):
355381
v1 = parse_value(m.group(1))
356382
v2 = parse_value(m.group(2))
357383
_cond = make_or(make_eq(v1), make_eq(v2))
384+
repr_template = "{} == " + str(v1) + " or {} == " + str(v2)
358385
elif group in ["finite", "a finite number"]:
359386
_cond = math.isfinite
387+
repr_template = "isfinite({})"
360388
elif group in "a positive (i.e., greater than ``0``) finite number":
361389
_cond = lambda i: math.isfinite(i) and i > 0
390+
repr_template = "isfinite({}) and {} > 0"
362391
elif group == "a negative (i.e., less than ``0``) finite number":
363392
_cond = lambda i: math.isfinite(i) and i < 0
393+
repr_template = "isfinite({}) and {} < 0"
364394
elif group == "positive":
365395
_cond = lambda i: math.copysign(1, i) == 1
396+
repr_template = "copysign(1, {}) == 1"
366397
elif group == "negative":
367398
_cond = lambda i: math.copysign(1, i) == -1
399+
repr_template = "copysign(1, {}) == -1"
368400
elif "nonzero finite" in group:
369401
_cond = lambda i: math.isfinite(i) and i != 0
402+
repr_template = "copysign(1, {}) == -1"
370403
elif group == "an integer value":
371404
_cond = lambda i: i.is_integer()
405+
repr_template = "{}.is_integer()"
372406
elif group == "an odd integer value":
373407
_cond = lambda i: i.is_integer() and i % 2 == 1
408+
repr_template = "{}.is_integer() and {} % 2 == 1"
374409
else:
375-
print(f"{group=}")
376410
raise ValueParseError(group)
377411

378412
if notify:
379413
final_cond = lambda i: not _cond(i)
380414
else:
381415
final_cond = _cond
382416

417+
f_i1 = x1_i
418+
f_i2 = x2_i
419+
if self.abs_:
420+
f_i1 = f"abs{f_i1}"
421+
f_i2 = f"abs{f_i2}"
422+
383423
if self.input_ == "i1":
424+
repr_ = repr_template.replace("{}", f_i1)
384425

385426
def cond(i1: float, i2: float) -> bool:
386427
return final_cond(i1)
387428

388429
elif self.input_ == "i2":
430+
repr_ = repr_template.replace("{}", f_i2)
389431

390432
def cond(i1: float, i2: float) -> bool:
391433
return final_cond(i2)
392434

393435
elif self.input_ == "either":
436+
repr_ = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})"
394437

395438
def cond(i1: float, i2: float) -> bool:
396439
return final_cond(i1) or final_cond(i2)
397440

398441
else:
442+
assert self.input_ == "both" # sanity check
443+
repr_ = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})"
399444

400445
def cond(i1: float, i2: float) -> bool:
401446
return final_cond(i1) and final_cond(i2)
402447

403-
return cond
404-
405-
406-
@dataclass
407-
class AbsCondFactory(CondFactory):
408-
cond_factory: CondFactory
409-
410-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
411-
_cond = self.cond_factory(groups)
412-
413-
def cond(i1: float, i2: float) -> bool:
414-
i1 = abs(i1)
415-
i2 = abs(i2)
416-
return _cond(i1, i2)
448+
if notify:
449+
repr_ = f"not ({repr_})"
417450

418-
return cond
451+
return BinaryCond(cond, repr_)
419452

420453

421-
class AndCondFactory(CondFactory):
422-
def __init__(self, *cond_factories: CondFactory):
454+
class AndCondFactory(BinaryCondFactory):
455+
def __init__(self, *cond_factories: BinaryCondFactory):
423456
self.cond_factories = cond_factories
424457

425-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
458+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
426459
conds = [cond_factory(groups) for cond_factory in self.cond_factories]
460+
repr_ = " and ".join(f"({cond!r})" for cond in conds)
427461

428462
def cond(i1: float, i2: float) -> bool:
429463
return all(cond(i1, i2) for cond in conds)
430464

431-
return cond
432-
433-
def __repr__(self) -> str:
434-
f_cond_factories = ", ".join(
435-
repr(cond_factory) for cond_factory in self.cond_factories
436-
)
437-
return f"{self.__class__.__name__}({f_cond_factories})"
465+
return BinaryCond(cond, repr_)
438466

439467

440468
@dataclass
441-
class SignCondFactory(CondFactory):
469+
class SignCondFactory(BinaryCondFactory):
442470
re_groups_i: int
443471

444472
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
@@ -451,45 +479,67 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
451479
raise ValueParseError(group)
452480

453481

454-
BinaryResultCheck = Callable[[float, float, float], bool]
482+
class BinaryResultCheck(NamedTuple):
483+
check_result: Callable[[float, float, float], bool]
484+
repr_: str
485+
486+
def __call__(self, i1: float, i2: float, result: float) -> bool:
487+
return self.check_result(i1, i2, result)
488+
489+
def __repr__(self):
490+
return self.repr_
491+
492+
493+
class BinaryResultCheckFactory(Protocol):
494+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
495+
...
455496

456497

457-
class ResultCheckFactory(NamedTuple):
498+
@dataclass
499+
class ResultCheckFactory(BinaryResultCheckFactory):
458500
re_groups_i: int
459501

460502
def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
461503
group = groups[self.re_groups_i]
462504

463505
if m := r_array_element.match(group):
464-
cond_factory = make_eq if m.group(1) != "-" else make_neq
506+
sign, input_ = m.groups()
507+
if sign == "-":
508+
signer = lambda i: -i
509+
else:
510+
signer = lambda i: i
465511

466-
if m.group(2) == "1":
512+
if input_ == "1":
513+
repr_ = f"{sign}{x1_i}"
467514

468-
def cond(i1: float, i2: float, result: float) -> bool:
469-
_cond = cond_factory(i1)
470-
return _cond(result)
515+
def check_result(i1: float, i2: float, result: float) -> bool:
516+
_check_result = make_eq(signer(i1))
517+
return _check_result(result)
471518

472519
else:
520+
repr_ = f"{sign}{x2_i}"
473521

474-
def cond(i1: float, i2: float, result: float) -> bool:
475-
_cond = cond_factory(i2)
476-
return _cond(result)
522+
def check_result(i1: float, i2: float, result: float) -> bool:
523+
_check_result = make_eq(signer(i2))
524+
return _check_result(result)
477525

478-
return cond
526+
return BinaryResultCheck(check_result, repr_)
479527

480528
if m := r_code.match(group):
481529
value = parse_value(m.group(1))
482-
_cond = make_eq(value)
530+
_check_result = make_eq(value)
531+
repr_ = str(value)
483532
elif m := r_approx_value.match(group):
484533
value = parse_value(m.group(1))
485-
_cond = make_rough_eq(value)
534+
_check_result = make_rough_eq(value)
535+
repr_ = f"~{value}"
486536
else:
487537
raise ValueParseError(group)
488538

489-
def cond(i1: float, i2: float, result: float) -> bool:
490-
return _cond(result)
539+
def check_result(i1: float, i2: float, result: float) -> bool:
540+
return _check_result(result)
491541

492-
return cond
542+
return BinaryResultCheck(check_result, repr_)
493543

494544

495545
class ResultSignCheckFactory(ResultCheckFactory):
@@ -516,12 +566,15 @@ def cond(i1: float, i2: float, result: float) -> bool:
516566

517567

518568
class BinaryCase(NamedTuple):
519-
cond: BinaryCheck
569+
cond: BinaryCond
520570
check_result: BinaryResultCheck
521571

572+
def __repr__(self):
573+
return f"BinaryCase(<{self.cond} -> {self.check_result}>)"
574+
522575

523576
class BinaryCaseFactory(NamedTuple):
524-
cond_factory: CondFactory
577+
cond_factory: BinaryCondFactory
525578
check_result_factory: ResultCheckFactory
526579

527580
def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
@@ -564,9 +617,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
564617
re.compile(
565618
r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
566619
): BinaryCaseFactory(
567-
AndCondFactory(
568-
AbsCondFactory(ValueCondFactory("i1", 0)), ValueCondFactory("i2", 1)
569-
),
620+
AndCondFactory(ValueCondFactory("i1", 0, abs_=True), ValueCondFactory("i2", 1)),
570621
ResultCheckFactory(2),
571622
),
572623
re.compile(
@@ -726,25 +777,11 @@ def test_binary(func_name, func, cases, x1, x2):
726777
if case.cond(l, r):
727778
good_example = True
728779
o = float(res[o_idx])
729-
assert case.check_result(l, r, o)
730-
# f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
731-
# f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
732-
# f_out = f"{sh.fmt_idx('out', o_idx)}={out}"
733-
# if result.strict_check:
734-
# msg = (
735-
# f"{f_out}, but should be {result.repr_} [{func_name}()]\n"
736-
# f"{f_left}, {f_right}"
737-
# )
738-
# if math.isnan(result.value):
739-
# assert math.isnan(out), msg
740-
# else:
741-
# assert out == result.value, msg
742-
# else:
743-
# assert math.isfinite(result.value) # sanity check
744-
# assert math.isclose(out, result.value, abs_tol=0.1), (
745-
# f"{f_out}, but should be roughly {result.repr_}={result.value} "
746-
# f"[{func_name}()]\n"
747-
# f"{f_left}, {f_right}"
748-
# )
780+
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
781+
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
782+
f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
783+
assert case.check_result(l, r, o), (
784+
f"{f_out} not good [{func_name}()]\n" f"{f_left}, {f_right}"
785+
)
749786
break
750787
assume(good_example)

0 commit comments

Comments
 (0)