Skip to content

Commit 727072f

Browse files
committed
Add testing and small typo fixes
1 parent 426609f commit 727072f

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

array_api_strict/_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __hash__(self):
121121
"integer": _integer_dtypes,
122122
"integer or boolean": _integer_or_boolean_dtypes,
123123
"boolean": _boolean_dtypes,
124-
"real floating-point": _floating_dtypes,
124+
"real floating-point": _real_floating_dtypes,
125125
"complex floating-point": _complex_floating_dtypes,
126126
"floating-point": _floating_dtypes,
127127
}

array_api_strict/_elementwise_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def copysign(x1: Array, x2: Array, /) -> Array:
354354
if x1.device != x2.device:
355355
raise RuntimeError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
356356

357-
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
357+
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
358358
raise TypeError("Only real numeric dtypes are allowed in copysign")
359359
# Call result type here just to raise on disallowed type combinations
360360
_result_type(x1.dtype, x2.dtype)
@@ -632,7 +632,7 @@ def log10(x: Array, /) -> Array:
632632
return Array._new(np.log10(x._array), device=x.device)
633633

634634

635-
def logaddexp(x1: Array, x2: Array) -> Array:
635+
def logaddexp(x1: Array, x2: Array, /) -> Array:
636636
"""
637637
Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.
638638

array_api_strict/tests/test_elementwise_functions.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from inspect import getfullargspec, getmodule
1+
from inspect import signature, getmodule
22

33
from numpy.testing import assert_raises
44

@@ -19,8 +19,16 @@
1919

2020
import pytest
2121

22+
import array_api_strict
23+
24+
2225
def nargs(func):
23-
return len(getfullargspec(func).args)
26+
"""Count number of 'array' arguments a function takes."""
27+
positional_only = 0
28+
for param in signature(func).parameters.values():
29+
if param.kind == param.POSITIONAL_ONLY:
30+
positional_only += 1
31+
return positional_only
2432

2533

2634
elementwise_function_input_types = {
@@ -91,12 +99,57 @@ def nargs(func):
9199
"trunc": "real numeric",
92100
}
93101

102+
103+
def test_nargs():
104+
# Explicitly check number of arguments for a few functions
105+
assert nargs(array_api_strict.logaddexp) == 2
106+
assert nargs(array_api_strict.atan2) == 2
107+
assert nargs(array_api_strict.clip) == 1
108+
109+
# All elementwise functions take one or two array arguments
110+
# if not, it is probably a bug in `nargs` or the definition
111+
# of the function (missing trailing `, /`).
112+
for func_name in elementwise_function_input_types:
113+
func = getattr(_elementwise_functions, func_name)
114+
assert nargs(func) in (1, 2)
115+
116+
94117
def test_missing_functions():
95118
# Ensure the above dictionary is complete.
96119
import array_api_strict._elementwise_functions as mod
97120
mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
98121
assert set(mod_funcs) == set(elementwise_function_input_types)
99122

123+
124+
def test_function_device_persists():
125+
# Test that the device of the input and output array are the same
126+
def _array_vals(dtypes):
127+
for d in dtypes:
128+
yield asarray(1., dtype=d)
129+
130+
# Use the latest version of the standard so all functions are included
131+
with pytest.warns(UserWarning):
132+
set_array_api_strict_flags(api_version="2023.12")
133+
134+
for func_name, types in elementwise_function_input_types.items():
135+
dtypes = _dtype_categories[types]
136+
func = getattr(_elementwise_functions, func_name)
137+
print(f"{func_name=} {nargs(func)=} {types=} {dtypes=}")
138+
139+
for x in _array_vals(dtypes):
140+
if nargs(func) == 2:
141+
# This way we don't have to deal with incompatible
142+
# types of the two arguments.
143+
r = func(x, x)
144+
assert r.device == x.device
145+
146+
else:
147+
if func_name == "atanh":
148+
x -= 0.1
149+
r = func(x)
150+
assert r.device == x.device
151+
152+
100153
def test_function_types():
101154
# Test that every function accepts only the required input types. We only
102155
# test the negative cases here (error). The positive cases are tested in
@@ -130,12 +183,12 @@ def _array_vals():
130183
or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
131184
or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
132185
):
133-
assert_raises(TypeError, lambda: func(x, y))
186+
assert_raises(TypeError, func, x, y)
134187
if x.dtype not in dtypes or y.dtype not in dtypes:
135-
assert_raises(TypeError, lambda: func(x, y))
188+
assert_raises(TypeError, func, x, y)
136189
else:
137190
if x.dtype not in dtypes:
138-
assert_raises(TypeError, lambda: func(x))
191+
assert_raises(TypeError, func, x)
139192

140193

141194
def test_bitwise_shift_error():

0 commit comments

Comments
 (0)