Skip to content

Commit 8b4b4f6

Browse files
committed
Add elementwise tests for positive, pow, and remainder
remainder is a bit unspecified IMO. It doesn't actually specify the behavior on nans, infinites, or remainder by 0. For now we just test reasonable behavior on those (the test passes with the NumPy implementation).
1 parent 1f70e0a commit 8b4b4f6

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
assert_integral, less_equal, isintegral, isfinite,
3737
ndindex, promote_dtypes, is_integer_dtype,
3838
is_float_dtype, not_equal, float64, asarray,
39-
dtype_ranges, full, true, false)
39+
dtype_ranges, full, true, false, assert_same_sign,
40+
isnan)
4041
# We might as well use this implementation rather than requiring
4142
# mod.broadcast_shapes(). See test_equal() and others.
4243
from .test_broadcasting import broadcast_shapes
@@ -798,20 +799,31 @@ def test_not_equal(args):
798799

799800
@given(numeric_scalars)
800801
def test_positive(x):
801-
# a = _array_module.positive(x)
802-
pass
802+
a = _array_module.positive(x)
803+
# Positive does nothing
804+
assert_exactly_equal(a, x)
803805

804806
@given(two_floating_dtypes.flatmap(lambda i: two_array_scalars(*i)))
805807
def test_pow(args):
806808
x1, x2 = args
807809
sanity_check(x1, x2)
808-
# a = _array_module.pow(x1, x2)
810+
_array_module.pow(x1, x2)
811+
# There isn't much we can test here. The spec doesn't require any behavior
812+
# beyond the special cases, and indeed, there aren't many mathematical
813+
# properties of exponentiation that strictly hold for floating-point
814+
# numbers. We could test that this does implement IEEE 754 pow, but we
815+
# don't yet have those sorts in general for this module.
809816

810817
@given(two_numeric_dtypes.flatmap(lambda i: two_array_scalars(*i)))
811818
def test_remainder(args):
812819
x1, x2 = args
813820
sanity_check(x1, x2)
814-
# a = _array_module.remainder(x1, x2)
821+
a = _array_module.remainder(x1, x2)
822+
823+
# a and x2 should have the same sign.
824+
# assert_same_sign returns False for nans
825+
not_nan = logical_not(logical_or(isnan(a), isnan(x2)))
826+
assert_same_sign(a[not_nan], x2[not_nan])
815827

816828
@given(numeric_scalars)
817829
def test_round(x):

0 commit comments

Comments
 (0)