|
1 |
| -from inspect import getfullargspec, getmodule |
| 1 | +from inspect import signature, getmodule |
2 | 2 |
|
3 | 3 | from numpy.testing import assert_raises
|
4 | 4 |
|
|
19 | 19 |
|
20 | 20 | import pytest
|
21 | 21 |
|
| 22 | +import array_api_strict |
| 23 | + |
| 24 | + |
22 | 25 | def nargs(func):
|
23 |
| - return len(getfullargspec(func).args) |
| 26 | + """Count number of 'array' arguments a function takes.""" |
| 27 | + positional_only = 0 |
| 28 | + for param in signature(func).parameters.values(): |
| 29 | + if param.kind == param.POSITIONAL_ONLY: |
| 30 | + positional_only += 1 |
| 31 | + return positional_only |
24 | 32 |
|
25 | 33 |
|
26 | 34 | elementwise_function_input_types = {
|
@@ -91,12 +99,57 @@ def nargs(func):
|
91 | 99 | "trunc": "real numeric",
|
92 | 100 | }
|
93 | 101 |
|
| 102 | + |
| 103 | +def test_nargs(): |
| 104 | + # Explicitly check number of arguments for a few functions |
| 105 | + assert nargs(array_api_strict.logaddexp) == 2 |
| 106 | + assert nargs(array_api_strict.atan2) == 2 |
| 107 | + assert nargs(array_api_strict.clip) == 1 |
| 108 | + |
| 109 | + # All elementwise functions take one or two array arguments |
| 110 | + # if not, it is probably a bug in `nargs` or the definition |
| 111 | + # of the function (missing trailing `, /`). |
| 112 | + for func_name in elementwise_function_input_types: |
| 113 | + func = getattr(_elementwise_functions, func_name) |
| 114 | + assert nargs(func) in (1, 2) |
| 115 | + |
| 116 | + |
94 | 117 | def test_missing_functions():
|
95 | 118 | # Ensure the above dictionary is complete.
|
96 | 119 | import array_api_strict._elementwise_functions as mod
|
97 | 120 | mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
|
98 | 121 | assert set(mod_funcs) == set(elementwise_function_input_types)
|
99 | 122 |
|
| 123 | + |
| 124 | +def test_function_device_persists(): |
| 125 | + # Test that the device of the input and output array are the same |
| 126 | + def _array_vals(dtypes): |
| 127 | + for d in dtypes: |
| 128 | + yield asarray(1., dtype=d) |
| 129 | + |
| 130 | + # Use the latest version of the standard so all functions are included |
| 131 | + with pytest.warns(UserWarning): |
| 132 | + set_array_api_strict_flags(api_version="2023.12") |
| 133 | + |
| 134 | + for func_name, types in elementwise_function_input_types.items(): |
| 135 | + dtypes = _dtype_categories[types] |
| 136 | + func = getattr(_elementwise_functions, func_name) |
| 137 | + print(f"{func_name=} {nargs(func)=} {types=} {dtypes=}") |
| 138 | + |
| 139 | + for x in _array_vals(dtypes): |
| 140 | + if nargs(func) == 2: |
| 141 | + # This way we don't have to deal with incompatible |
| 142 | + # types of the two arguments. |
| 143 | + r = func(x, x) |
| 144 | + assert r.device == x.device |
| 145 | + |
| 146 | + else: |
| 147 | + if func_name == "atanh": |
| 148 | + x -= 0.1 |
| 149 | + r = func(x) |
| 150 | + assert r.device == x.device |
| 151 | + |
| 152 | + |
100 | 153 | def test_function_types():
|
101 | 154 | # Test that every function accepts only the required input types. We only
|
102 | 155 | # test the negative cases here (error). The positive cases are tested in
|
@@ -130,12 +183,12 @@ def _array_vals():
|
130 | 183 | or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
|
131 | 184 | or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
|
132 | 185 | ):
|
133 |
| - assert_raises(TypeError, lambda: func(x, y)) |
| 186 | + assert_raises(TypeError, func, x, y) |
134 | 187 | if x.dtype not in dtypes or y.dtype not in dtypes:
|
135 |
| - assert_raises(TypeError, lambda: func(x, y)) |
| 188 | + assert_raises(TypeError, func, x, y) |
136 | 189 | else:
|
137 | 190 | if x.dtype not in dtypes:
|
138 |
| - assert_raises(TypeError, lambda: func(x)) |
| 191 | + assert_raises(TypeError, func, x) |
139 | 192 |
|
140 | 193 |
|
141 | 194 | def test_bitwise_shift_error():
|
|
0 commit comments