Skip to content

Commit fd8481e

Browse files
committed
Parse more awkward equality special cases
1 parent f3f683f commit fd8481e

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

array_api_tests/test_special_cases.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import math
33
import re
4+
from dataclasses import dataclass
45
from typing import (
56
Callable,
67
Dict,
@@ -15,7 +16,6 @@
1516
from warnings import warn
1617

1718
import pytest
18-
from attr import dataclass
1919
from hypothesis import HealthCheck, assume, given, settings
2020

2121
from . import dtype_helpers as dh
@@ -524,13 +524,28 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
524524
AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)),
525525
ResultCheckFactory(2),
526526
),
527-
# re.compile(
528-
# "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
529-
# ): lambda v: lambda _, i2: make_eq(v)(i2),
530-
# re.compile(
531-
# "If ``x1_i`` is (.+), ``x1_i`` (.+), "
532-
# "and ``x2_i`` is (.+), the result is (.+)"
533-
# )
527+
re.compile(
528+
"If ``x1_i`` is (.+), ``x1_i`` (.+), "
529+
"and ``x2_i`` is (.+), the result is (.+)"
530+
): BinaryCaseFactory(
531+
AndCondFactory(
532+
ValueCondFactory("i1", 0),
533+
ValueCondFactory("i1", 1),
534+
ValueCondFactory("i2", 2),
535+
),
536+
ResultCheckFactory(3),
537+
),
538+
re.compile(
539+
"If ``x1_i`` is (.+), ``x2_i`` (.+), "
540+
"and ``x2_i`` is (.+), the result is (.+)"
541+
): BinaryCaseFactory(
542+
AndCondFactory(
543+
ValueCondFactory("i1", 0),
544+
ValueCondFactory("i2", 1),
545+
ValueCondFactory("i2", 2),
546+
),
547+
ResultCheckFactory(3),
548+
),
534549
# re.compile(
535550
# r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), "
536551
# "the result is (.+)"
@@ -560,6 +575,9 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
560575
rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result "
561576
r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\."
562577
): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)),
578+
re.compile(
579+
"If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+"
580+
): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)),
563581
}
564582

565583

0 commit comments

Comments
 (0)