Skip to content

Commit 1f70e0a

Browse files
committed
Let the same_sign helper work with integer dtypes
1 parent 5269fae commit 1f70e0a

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

array_api_tests/array_helpers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,24 @@ def false(shape):
106106

107107
def isnegzero(x):
108108
"""
109-
Returns a mask where x is -0.
109+
Returns a mask where x is -0. Is all False if x has integer dtype.
110110
"""
111111
# TODO: If copysign or signbit are added to the spec, use those instead.
112112
shape = x.shape
113113
dtype = x.dtype
114+
if is_integer_dtype(dtype):
115+
return false(shape)
114116
return equal(divide(one(shape, dtype), x), -infinity(shape, dtype))
115117

116118
def isposzero(x):
117119
"""
118-
Returns a mask where x is +0 (but not -0).
120+
Returns a mask where x is +0 (but not -0). Is all True if x has integer dtype.
119121
"""
120122
# TODO: If copysign or signbit are added to the spec, use those instead.
121123
shape = x.shape
122124
dtype = x.dtype
125+
if is_integer_dtype(dtype):
126+
return true(shape)
123127
return equal(divide(one(shape, dtype), x), infinity(shape, dtype))
124128

125129
def exactly_equal(x, y):
@@ -307,7 +311,6 @@ def same_sign(x, y):
307311
def assert_same_sign(x, y):
308312
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
309313

310-
311314
integer_dtype_objects = [getattr(_array_module, t) for t in _integer_dtypes]
312315
floating_dtype_objects = [getattr(_array_module, t) for t in _floating_dtypes]
313316
numeric_dtype_objects = [getattr(_array_module, t) for t in _numeric_dtypes]

0 commit comments

Comments
 (0)