Skip to content

Commit b5af57d

Browse files
committed
Parse abs special cases
1 parent fd8481e commit b5af57d

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

array_api_tests/test_special_cases.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,21 @@ def cond(i1: float, i2: float) -> bool:
403403
return cond
404404

405405

406+
@dataclass
407+
class AbsCondFactory(CondFactory):
408+
cond_factory: CondFactory
409+
410+
def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck:
411+
_cond = self.cond_factory(groups)
412+
413+
def cond(i1: float, i2: float) -> bool:
414+
i1 = abs(i1)
415+
i2 = abs(i2)
416+
return _cond(i1, i2)
417+
418+
return cond
419+
420+
406421
class AndCondFactory(CondFactory):
407422
def __init__(self, *cond_factories: CondFactory):
408423
self.cond_factories = cond_factories
@@ -546,16 +561,14 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
546561
),
547562
ResultCheckFactory(3),
548563
),
549-
# re.compile(
550-
# r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), "
551-
# "the result is (.+)"
552-
# ): make_bin_and_factory(absify_cond_factory(make_gt), make_eq),
553-
# re.compile(
554-
# r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)"
555-
# ): make_bin_and_factory(absify_cond_factory(make_lt), make_eq),
556-
# re.compile(
557-
# r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)"
558-
# ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq),
564+
re.compile(
565+
r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
566+
): BinaryCaseFactory(
567+
AndCondFactory(
568+
AbsCondFactory(ValueCondFactory("i1", 0)), ValueCondFactory("i2", 1)
569+
),
570+
ResultCheckFactory(2),
571+
),
559572
re.compile(
560573
"If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)"
561574
): BinaryCaseFactory(ValueCondFactory("either", 0), ResultCheckFactory(1)),

0 commit comments

Comments
 (0)