Skip to content

Commit dfda4f5

Browse files
committed
Pass but filter out-of-range values for trig function tests
1 parent 56aa06d commit dfda4f5

File tree

1 file changed

+40
-64
lines changed

1 file changed

+40
-64
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 40 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -446,30 +446,24 @@ def test_abs(ctx, data):
446446
)
447447

448448

449-
@given(
450-
xps.arrays(
451-
dtype=xps.floating_dtypes(),
452-
shape=hh.shapes(),
453-
elements={"min_value": -1, "max_value": 1},
454-
)
455-
)
449+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
456450
def test_acos(x):
457451
out = xp.acos(x)
458452
ph.assert_dtype("acos", x.dtype, out.dtype)
459453
ph.assert_shape("acos", out.shape, x.shape)
460-
unary_assert_against_refimpl("acos", x, out, math.acos)
454+
unary_assert_against_refimpl(
455+
"acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1
456+
)
461457

462458

463-
@given(
464-
xps.arrays(
465-
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
466-
)
467-
)
459+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
468460
def test_acosh(x):
469461
out = xp.acosh(x)
470462
ph.assert_dtype("acosh", x.dtype, out.dtype)
471463
ph.assert_shape("acosh", out.shape, x.shape)
472-
unary_assert_against_refimpl("acosh", x, out, math.acosh)
464+
unary_assert_against_refimpl(
465+
"acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1
466+
)
473467

474468

475469
@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
@@ -488,18 +482,14 @@ def test_add(ctx, data):
488482
binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add)
489483

490484

491-
@given(
492-
xps.arrays(
493-
dtype=xps.floating_dtypes(),
494-
shape=hh.shapes(),
495-
elements={"min_value": -1, "max_value": 1},
496-
)
497-
)
485+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
498486
def test_asin(x):
499487
out = xp.asin(x)
500488
ph.assert_dtype("asin", x.dtype, out.dtype)
501489
ph.assert_shape("asin", out.shape, x.shape)
502-
unary_assert_against_refimpl("asin", x, out, math.asin)
490+
unary_assert_against_refimpl(
491+
"asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1
492+
)
503493

504494

505495
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -526,18 +516,18 @@ def test_atan2(x1, x2):
526516
binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2)
527517

528518

529-
@given(
530-
xps.arrays(
531-
dtype=xps.floating_dtypes(),
532-
shape=hh.shapes(),
533-
elements={"min_value": -1, "max_value": 1},
534-
)
535-
)
519+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
536520
def test_atanh(x):
537521
out = xp.atanh(x)
538522
ph.assert_dtype("atanh", x.dtype, out.dtype)
539523
ph.assert_shape("atanh", out.shape, x.shape)
540-
unary_assert_against_refimpl("atanh", x, out, math.atanh)
524+
unary_assert_against_refimpl(
525+
"atanh",
526+
x,
527+
out,
528+
math.atanh,
529+
filter_=lambda s: default_filter(s) and -1 <= s <= 1,
530+
)
541531

542532

543533
@pytest.mark.parametrize(
@@ -899,56 +889,44 @@ def test_less_equal(ctx, data):
899889
)
900890

901891

902-
@given(
903-
xps.arrays(
904-
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
905-
)
906-
)
892+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
907893
def test_log(x):
908894
out = xp.log(x)
909895
ph.assert_dtype("log", x.dtype, out.dtype)
910896
ph.assert_shape("log", out.shape, x.shape)
911-
unary_assert_against_refimpl("log", x, out, math.log)
897+
unary_assert_against_refimpl(
898+
"log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1
899+
)
912900

913901

914-
@given(
915-
xps.arrays(
916-
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 1}
917-
)
918-
)
902+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
919903
def test_log1p(x):
920904
out = xp.log1p(x)
921905
ph.assert_dtype("log1p", x.dtype, out.dtype)
922906
ph.assert_shape("log1p", out.shape, x.shape)
923-
unary_assert_against_refimpl("log1p", x, out, math.log1p)
907+
unary_assert_against_refimpl(
908+
"log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1
909+
)
924910

925911

926-
@given(
927-
xps.arrays(
928-
dtype=xps.floating_dtypes(),
929-
shape=hh.shapes(),
930-
elements={"min_value": 0, "exclude_min": True},
931-
)
932-
)
912+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
933913
def test_log2(x):
934914
out = xp.log2(x)
935915
ph.assert_dtype("log2", x.dtype, out.dtype)
936916
ph.assert_shape("log2", out.shape, x.shape)
937-
unary_assert_against_refimpl("log2", x, out, math.log2)
917+
unary_assert_against_refimpl(
918+
"log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1
919+
)
938920

939921

940-
@given(
941-
xps.arrays(
942-
dtype=xps.floating_dtypes(),
943-
shape=hh.shapes(),
944-
elements={"min_value": 0, "exclude_min": True},
945-
)
946-
)
922+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
947923
def test_log10(x):
948924
out = xp.log10(x)
949925
ph.assert_dtype("log10", x.dtype, out.dtype)
950926
ph.assert_shape("log10", out.shape, x.shape)
951-
unary_assert_against_refimpl("log10", x, out, math.log10)
927+
unary_assert_against_refimpl(
928+
"log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0
929+
)
952930

953931

954932
@given(*hh.two_mutual_arrays(dh.float_dtypes))
@@ -1166,16 +1144,14 @@ def test_square(x):
11661144
)
11671145

11681146

1169-
@given(
1170-
xps.arrays(
1171-
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 0}
1172-
)
1173-
)
1147+
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
11741148
def test_sqrt(x):
11751149
out = xp.sqrt(x)
11761150
ph.assert_dtype("sqrt", x.dtype, out.dtype)
11771151
ph.assert_shape("sqrt", out.shape, x.shape)
1178-
unary_assert_against_refimpl("sqrt", x, out, math.sqrt)
1152+
unary_assert_against_refimpl(
1153+
"sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0
1154+
)
11791155

11801156

11811157
@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))

0 commit comments

Comments
 (0)