Skip to content

Commit 0dedf55

Browse files
committed
Parse quad cases
1 parent 49ef10d commit 0dedf55

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

array_api_tests/test_special_cases.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from warnings import warn
1717

1818
import pytest
19-
from hypothesis import HealthCheck, assume, given, settings
19+
from hypothesis import assume, given
2020

2121
from . import dtype_helpers as dh
2222
from . import hypothesis_helpers as hh
@@ -580,7 +580,19 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
580580

581581
binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = {
582582
re.compile(
583-
"If ``x1_i`` is (.+), ``x1_i`` (.+), "
583+
"If ``x1_i`` is (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), "
584+
"and ``x2_i`` is (.+), the result is (.+)"
585+
): BinaryCaseFactory(
586+
AndCondFactory(
587+
ValueCondFactory("i1", 0),
588+
ValueCondFactory("i1", 1),
589+
ValueCondFactory("i2", 2),
590+
ValueCondFactory("i2", 3),
591+
),
592+
ResultCheckFactory(4),
593+
),
594+
re.compile(
595+
"If ``x1_i`` is (.+), ``x1_i`` is (.+), "
584596
"and ``x2_i`` is (.+), the result is (.+)"
585597
): BinaryCaseFactory(
586598
AndCondFactory(
@@ -591,7 +603,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
591603
ResultCheckFactory(3),
592604
),
593605
re.compile(
594-
"If ``x1_i`` is (.+), ``x2_i`` (.+), "
606+
"If ``x1_i`` is (.+), ``x2_i`` is (.+), "
595607
"and ``x2_i`` is (.+), the result is (.+)"
596608
): BinaryCaseFactory(
597609
AndCondFactory(
@@ -601,7 +613,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
601613
),
602614
ResultCheckFactory(3),
603615
),
604-
# This case must come after the above to avoid false matches
616+
# This pattern must come after the above to avoid false matches
605617
re.compile(
606618
"If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
607619
): BinaryCaseFactory(
@@ -760,7 +772,6 @@ def test_unary(func_name, func, cases, x):
760772
two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1),
761773
)
762774
)
763-
@settings(suppress_health_check=[HealthCheck.filter_too_much]) # TODO: remove
764775
def test_binary(func_name, func, cases, x1, x2):
765776
res = func(x1, x2)
766777
good_example = False

0 commit comments

Comments
 (0)