Skip to content

Commit 645451f

Browse files
committed
Generalise not special cases
1 parent e582545 commit 645451f

File tree

1 file changed

+21
-80
lines changed

1 file changed

+21
-80
lines changed

array_api_tests/test_special_cases.py

Lines changed: 21 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
304304
...
305305

306306

307-
r_not_code = re.compile(f"not (?:equal to )?{r_code.pattern}")
307+
r_not = re.compile("not (?:equal to )?(.+)")
308308
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
309309
r_gt = re.compile(f"greater than {r_code.pattern}")
310310
r_lt = re.compile(f"less than {r_code.pattern}")
@@ -313,10 +313,10 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
313313

314314
class ValueCondFactory(NamedTuple): # TODO: inherit from CondFactory as well
315315
input_: Union[Literal["i1"], Literal["i2"]]
316-
re_group: int
316+
groups_i: int
317317

318318
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
319-
group = groups[self.re_group]
319+
group = groups[self.groups_i]
320320

321321
if m := r_array_element.match(group):
322322
cond_factory = make_eq if m.group(1) != "-" else make_neq
@@ -334,12 +334,15 @@ def cond(i1: float, i2: float) -> bool:
334334

335335
return cond
336336

337+
if m := r_not.match(group):
338+
group = m.group(1)
339+
notify = True
340+
else:
341+
notify = False
342+
337343
if m := r_code.match(group):
338344
value = parse_value(m.group(1))
339345
_cond = make_eq(value)
340-
elif m := r_not_code.match(group):
341-
value = parse_value(m.group(1))
342-
_cond = make_neq(value)
343346
elif m := r_gt.match(group):
344347
value = parse_value(m.group(1))
345348
_cond = make_gt(value)
@@ -364,24 +367,26 @@ def cond(i1: float, i2: float) -> bool:
364367
_cond = lambda i: math.isfinite(i) and i != 0
365368
elif group == "an integer value":
366369
_cond = lambda i: i.is_integer()
367-
elif group == "not an integer value":
368-
_cond = lambda i: not i.is_integer()
369370
elif group == "an odd integer value":
370371
_cond = lambda i: i.is_integer() and i % 2 == 1
371-
elif group == "not an odd integer value":
372-
_cond = lambda i: not (i.is_integer() and i % 2 == 1)
373372
else:
373+
print(f"{group=}")
374374
raise ValueParseError(group)
375375

376+
if notify:
377+
final_cond = lambda i: not _cond(i)
378+
else:
379+
final_cond = _cond
380+
376381
if self.input_ == "i1":
377382

378383
def cond(i1: float, i2: float) -> bool:
379-
return _cond(i1)
384+
return final_cond(i1)
380385

381386
else:
382387

383388
def cond(i1: float, i2: float) -> bool:
384-
return _cond(i2)
389+
return final_cond(i2)
385390

386391
return cond
387392

@@ -409,10 +414,10 @@ def __repr__(self) -> str:
409414

410415

411416
class ResultCheckFactory(NamedTuple):
412-
re_group: int
417+
groups_i: int
413418

414419
def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck:
415-
group = groups[self.re_group]
420+
group = groups[self.groups_i]
416421

417422
if m := r_array_element.match(group):
418423
cond_factory = make_eq if m.group(1) != "-" else make_neq
@@ -472,54 +477,9 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
472477
# "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
473478
# ): lambda v: lambda _, i2: make_eq(v)(i2),
474479
# re.compile(
475-
# "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)"
476-
# ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v),
477-
# re.compile(
478-
# "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), "
479-
# "and ``x2_i`` is (.+), the result is (.+)"
480-
# ): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]),
481-
# re.compile(
482-
# "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), "
483-
# "and ``x2_i`` is (.+), the result is (.+)"
484-
# ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]),
485-
# re.compile(
486-
# "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), "
487-
# "and ``x2_i`` is not (.+), the result is (.+)"
488-
# ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]),
489-
# re.compile(
490-
# "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
491-
# "and ``x2_i`` is (.+), the result is (.+)"
492-
# ): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]),
493-
# re.compile(
494-
# "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), "
495-
# "and ``x2_i`` is not (.+), the result is (.+)"
496-
# ): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]),
497-
# re.compile(
498-
# "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
480+
# "If ``x1_i`` is (.+), ``x1_i`` (.+), "
499481
# "and ``x2_i`` is (.+), the result is (.+)"
500-
# ): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]),
501-
# re.compile(
502-
# "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), "
503-
# "and ``x2_i`` is not (.+), the result is (.+)"
504-
# ): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]),
505-
# re.compile(
506-
# "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)"
507-
# ): make_bin_and_factory(make_gt, make_eq),
508-
# re.compile(
509-
# "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)"
510-
# ): make_bin_and_factory(make_eq, make_gt),
511-
# re.compile(
512-
# "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
513-
# ): make_bin_and_factory(make_lt, make_eq),
514-
# re.compile(
515-
# "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)"
516-
# ): make_bin_and_factory(make_eq, make_lt),
517-
# re.compile(
518-
# "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)"
519-
# ): make_bin_and_factory(make_neq, make_eq),
520-
# re.compile(
521-
# "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)"
522-
# ): make_bin_and_factory(make_eq, make_neq),
482+
# )
523483
# re.compile(
524484
# r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), "
525485
# "the result is (.+)"
@@ -537,25 +497,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
537497
# "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
538498
# ): make_bin_or_factory(make_eq),
539499
# re.compile(
540-
# "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)"
541-
# ): lambda v1, v2, v3: (
542-
# lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2)
543-
# ),
544-
# re.compile(
545-
# "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)"
546-
# ): lambda v1, v2, v3: (
547-
# lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2)
548-
# ),
549-
# re.compile(
550-
# "If ``x1_i`` is either (.+) or (.+) and "
551-
# "``x2_i`` is either (.+) or (.+), the result is (.+)"
552-
# ): lambda v1, v2, v3, v4: (
553-
# lambda i1, i2: (
554-
# make_or(make_eq(v1), make_eq(v2))(i1)
555-
# and make_or(make_eq(v3), make_eq(v4))(i2)
556-
# )
557-
# ),
558-
# re.compile(
559500
# "If ``x1_i`` and ``x2_i`` have the same mathematical sign, "
560501
# "the result has a (.+)"
561502
# ): lambda: same_sign,

0 commit comments

Comments
 (0)