Skip to content

Commit 22adf72

Browse files
committed
repr_ -> expr, return BinaryCond in SignCondFactory
1 parent 3971c3f commit 22adf72

File tree

1 file changed

+88
-100
lines changed

1 file changed

+88
-100
lines changed

array_api_tests/test_special_cases.py

Lines changed: 88 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -222,27 +222,6 @@ def parse_inline_code(inline_code: str) -> float:
222222
raise ValueParseError(inline_code)
223223

224224

225-
class Result(NamedTuple):
226-
value: float
227-
repr_: str
228-
strict_check: bool
229-
230-
231-
def parse_result(s_result: str) -> Result:
232-
match = None
233-
if m := r_code.match(s_result):
234-
match = m
235-
strict_check = True
236-
elif m := r_approx_value.match(s_result):
237-
match = m
238-
strict_check = False
239-
else:
240-
raise ValueParseError(s_result)
241-
value = parse_value(match.group(1))
242-
repr_ = match.group(1)
243-
return Result(value, repr_, strict_check)
244-
245-
246225
r_special_cases = re.compile(
247226
r"\*\*Special [Cc]ases\*\*\n+\s*"
248227
r"For floating-point operands,\n+"
@@ -252,65 +231,74 @@ def parse_result(s_result: str) -> Result:
252231
r_remaining_case = re.compile("In the remaining cases.+")
253232

254233

255-
unary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
256-
re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt,
257-
re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt,
258-
re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): (
259-
lambda v1, v2: make_or(make_eq(v1), make_eq(v2))
260-
),
261-
# This pattern must come after the previous patterns to avoid unwanted matches
262-
re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq,
263-
re.compile(
264-
"If two integers are equally close to ``x_i``, the result is (.+)"
265-
): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
266-
}
234+
# unary_pattern_to_condition_factory: Dict[Pattern, Callable] = {
235+
# re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt,
236+
# re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt,
237+
# re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): (
238+
# lambda v1, v2: make_or(make_eq(v1), make_eq(v2))
239+
# ),
240+
# # This pattern must come after the previous patterns to avoid unwanted matches
241+
# re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq,
242+
# re.compile(
243+
# "If two integers are equally close to ``x_i``, the result is (.+)"
244+
# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
245+
# }
246+
247+
248+
# def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
249+
# match = r_special_cases.search(docstring)
250+
# if match is None:
251+
# return {}
252+
# cases = match.group(1).split("\n")[:-1]
253+
# cases = {}
254+
# for line in cases:
255+
# if m := r_case.match(line):
256+
# case = m.group(1)
257+
# else:
258+
# warn(f"line not machine-readable: '{line}'")
259+
# continue
260+
# for pattern, make_cond in unary_pattern_to_condition_factory.items():
261+
# if m := pattern.search(case):
262+
# *s_values, s_result = m.groups()
263+
# try:
264+
# values = [parse_inline_code(v) for v in s_values]
265+
# except ValueParseError as e:
266+
# warn(f"value not machine-readable: '{e.value}'")
267+
# break
268+
# cond = make_cond(*values)
269+
# try:
270+
# result = parse_result(s_result)
271+
# except ValueParseError as e:
272+
# warn(f"result not machine-readable: '{e.value}'")
273+
274+
# break
275+
# cases[cond] = result
276+
# break
277+
# else:
278+
# if not r_remaining_case.search(case):
279+
# warn(f"case not machine-readable: '{case}'")
280+
# return cases
281+
282+
x_i = "xᵢ"
283+
x1_i = "x1ᵢ"
284+
x2_i = "x2ᵢ"
267285

268286

269-
def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
270-
match = r_special_cases.search(docstring)
271-
if match is None:
272-
return {}
273-
cases = match.group(1).split("\n")[:-1]
274-
cases = {}
275-
for line in cases:
276-
if m := r_case.match(line):
277-
case = m.group(1)
278-
else:
279-
warn(f"line not machine-readable: '{line}'")
280-
continue
281-
for pattern, make_cond in unary_pattern_to_condition_factory.items():
282-
if m := pattern.search(case):
283-
*s_values, s_result = m.groups()
284-
try:
285-
values = [parse_inline_code(v) for v in s_values]
286-
except ValueParseError as e:
287-
warn(f"value not machine-readable: '{e.value}'")
288-
break
289-
cond = make_cond(*values)
290-
try:
291-
result = parse_result(s_result)
292-
except ValueParseError as e:
293-
warn(f"result not machine-readable: '{e.value}'")
287+
class Cond(Protocol):
288+
expr: str
294289

295-
break
296-
cases[cond] = result
297-
break
298-
else:
299-
if not r_remaining_case.search(case):
300-
warn(f"case not machine-readable: '{case}'")
301-
return cases
290+
def __call__(self, *args) -> bool:
291+
...
302292

303293

304-
class BinaryCond(NamedTuple):
294+
@dataclass
295+
class BinaryCond(Cond):
305296
cond: BinaryCheck
306-
repr_: str
297+
expr: str
307298

308299
def __call__(self, i1: float, i2: float) -> bool:
309300
return self.cond(i1, i2)
310301

311-
def __repr__(self):
312-
return self.repr_
313-
314302

315303
class BinaryCondFactory(Protocol):
316304
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
@@ -323,9 +311,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
323311
r_gt = re.compile(f"greater than {r_code.pattern}")
324312
r_lt = re.compile(f"less than {r_code.pattern}")
325313

326-
x1_i = "x1ᵢ"
327-
x2_i = "x2ᵢ"
328-
329314

330315
@dataclass
331316
class ValueCondFactory(BinaryCondFactory):
@@ -345,21 +330,21 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
345330
signer = lambda i: i
346331

347332
if self.input_ == "i1":
348-
repr_ = f"{x1_i} == {sign}{x2_i}"
333+
expr = f"{x1_i} == {sign}{x2_i}"
349334

350335
def cond(i1: float, i2: float) -> bool:
351336
_cond = make_eq(signer(i2))
352337
return _cond(i1)
353338

354339
else:
355340
assert self.input_ == "i2" # sanity check
356-
repr_ = f"{x2_i} == {sign}{x1_i}"
341+
expr = f"{x2_i} == {sign}{x1_i}"
357342

358343
def cond(i1: float, i2: float) -> bool:
359344
_cond = make_eq(signer(i1))
360345
return _cond(i2)
361346

362-
return BinaryCond(cond, repr_)
347+
return BinaryCond(cond, expr)
363348

364349
if m := r_not.match(group):
365350
group = m.group(1)
@@ -419,38 +404,38 @@ def cond(i1: float, i2: float) -> bool:
419404
f_i1 = x1_i
420405
f_i2 = x2_i
421406
if self.abs_:
422-
f_i1 = f"abs{f_i1}"
423-
f_i2 = f"abs{f_i2}"
407+
f_i1 = f"abs({f_i1})"
408+
f_i2 = f"abs({f_i2})"
424409

425410
if self.input_ == "i1":
426-
repr_ = repr_template.replace("{}", f_i1)
411+
expr = repr_template.replace("{}", f_i1)
427412

428413
def cond(i1: float, i2: float) -> bool:
429414
return final_cond(i1)
430415

431416
elif self.input_ == "i2":
432-
repr_ = repr_template.replace("{}", f_i2)
417+
expr = repr_template.replace("{}", f_i2)
433418

434419
def cond(i1: float, i2: float) -> bool:
435420
return final_cond(i2)
436421

437422
elif self.input_ == "either":
438-
repr_ = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})"
423+
expr = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})"
439424

440425
def cond(i1: float, i2: float) -> bool:
441426
return final_cond(i1) or final_cond(i2)
442427

443428
else:
444429
assert self.input_ == "both" # sanity check
445-
repr_ = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})"
430+
expr = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})"
446431

447432
def cond(i1: float, i2: float) -> bool:
448433
return final_cond(i1) and final_cond(i2)
449434

450435
if notify:
451-
repr_ = f"not ({repr_})"
436+
expr = f"not ({expr})"
452437

453-
return BinaryCond(cond, repr_)
438+
return BinaryCond(cond, expr)
454439

455440

456441
class AndCondFactory(BinaryCondFactory):
@@ -459,37 +444,40 @@ def __init__(self, *cond_factories: BinaryCondFactory):
459444

460445
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
461446
conds = [cond_factory(groups) for cond_factory in self.cond_factories]
462-
repr_ = " and ".join(f"({cond!r})" for cond in conds)
447+
expr = " and ".join(f"({cond.expr})" for cond in conds)
463448

464449
def cond(i1: float, i2: float) -> bool:
465450
return all(cond(i1, i2) for cond in conds)
466451

467-
return BinaryCond(cond, repr_)
452+
return BinaryCond(cond, expr)
468453

469454

470455
@dataclass
471456
class SignCondFactory(BinaryCondFactory):
472457
re_groups_i: int
473458

474-
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
459+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCond:
475460
group = groups[self.re_groups_i]
476461
if group == "the same mathematical sign":
477-
return same_sign
462+
cond = same_sign
463+
expr = f"copysign(1, {x1_i}) == copysign(1, {x2_i})"
478464
elif group == "different mathematical signs":
479-
return diff_sign
465+
cond = diff_sign
466+
expr = f"copysign(1, {x1_i}) != copysign(1, {x2_i})"
480467
else:
481468
raise ValueParseError(group)
469+
return BinaryCond(cond, expr)
482470

483471

484472
class BinaryResultCheck(NamedTuple):
485473
check_result: Callable[[float, float, float], bool]
486-
repr_: str
474+
expr: str
487475

488476
def __call__(self, i1: float, i2: float, result: float) -> bool:
489477
return self.check_result(i1, i2, result)
490478

491479
def __repr__(self):
492-
return self.repr_
480+
return self.expr
493481

494482

495483
class BinaryResultCheckFactory(Protocol):
@@ -512,36 +500,36 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
512500
signer = lambda i: i
513501

514502
if input_ == "1":
515-
repr_ = f"{sign}{x1_i}"
503+
expr = f"{sign}{x1_i}"
516504

517505
def check_result(i1: float, i2: float, result: float) -> bool:
518506
_check_result = make_eq(signer(i1))
519507
return _check_result(result)
520508

521509
else:
522-
repr_ = f"{sign}{x2_i}"
510+
expr = f"{sign}{x2_i}"
523511

524512
def check_result(i1: float, i2: float, result: float) -> bool:
525513
_check_result = make_eq(signer(i2))
526514
return _check_result(result)
527515

528-
return BinaryResultCheck(check_result, repr_)
516+
return BinaryResultCheck(check_result, expr)
529517

530518
if m := r_code.match(group):
531519
value = parse_value(m.group(1))
532520
_check_result = make_eq(value)
533-
repr_ = str(value)
521+
expr = str(value)
534522
elif m := r_approx_value.match(group):
535523
value = parse_value(m.group(1))
536524
_check_result = make_rough_eq(value)
537-
repr_ = f"~{value}"
525+
expr = f"~{value}"
538526
else:
539527
raise ValueParseError(group)
540528

541529
def check_result(i1: float, i2: float, result: float) -> bool:
542530
return _check_result(result)
543531

544-
return BinaryResultCheck(check_result, repr_)
532+
return BinaryResultCheck(check_result, expr)
545533

546534

547535
class ResultSignCheckFactory(ResultCheckFactory):
@@ -572,7 +560,7 @@ class BinaryCase(NamedTuple):
572560
check_result: BinaryResultCheck
573561

574562
def __repr__(self):
575-
return f"BinaryCase(<{self.cond} -> {self.check_result}>)"
563+
return f"BinaryCase(<{self.cond.expr} -> {self.check_result}>)"
576564

577565

578566
class BinaryCaseFactory(NamedTuple):
@@ -743,7 +731,7 @@ def test_unary(func_name, func, cases, x):
743731
f_out = f"{sh.fmt_idx('out', idx)}={out}"
744732
if result.strict_check:
745733
msg = (
746-
f"{f_out}, but should be {result.repr_} [{func_name}()]\n"
734+
f"{f_out}, but should be {result.expr} [{func_name}()]\n"
747735
f"{f_in}"
748736
)
749737
if math.isnan(result.value):
@@ -753,7 +741,7 @@ def test_unary(func_name, func, cases, x):
753741
else:
754742
assert math.isfinite(result.value) # sanity check
755743
assert math.isclose(out, result.value, abs_tol=0.1), (
756-
f"{f_out}, but should be roughly {result.repr_}={result.value} "
744+
f"{f_out}, but should be roughly {result.expr}={result.value} "
757745
f"[{func_name}()]\n"
758746
f"{f_in}"
759747
)

0 commit comments

Comments
 (0)