|
7 | 7 |
|
8 | 8 | import math
|
9 | 9 | from types import ModuleType
|
10 |
| -from typing import cast |
| 10 | +from typing import Any, cast |
11 | 11 |
|
12 | 12 | import numpy as np
|
13 | 13 | import pytest
|
|
17 | 17 | is_array_api_strict_namespace,
|
18 | 18 | is_cupy_namespace,
|
19 | 19 | is_dask_namespace,
|
| 20 | + is_jax_namespace, |
20 | 21 | is_numpy_namespace,
|
21 | 22 | is_pydata_sparse_namespace,
|
22 | 23 | is_torch_namespace,
|
| 24 | + to_device, |
23 | 25 | )
|
24 |
| -from ._utils._typing import Array |
| 26 | +from ._utils._typing import Array, Device |
25 | 27 |
|
26 |
| -__all__ = ["xp_assert_close", "xp_assert_equal"] |
| 28 | +__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"] |
27 | 29 |
|
28 | 30 |
|
29 | 31 | def _check_ns_shape_dtype(
|
@@ -81,23 +83,28 @@ def _check_ns_shape_dtype(
|
81 | 83 | return desired_xp
|
82 | 84 |
|
83 | 85 |
|
84 |
| -def _prepare_for_test(array: Array, xp: ModuleType) -> Array: |
| 86 | +def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any] |
85 | 87 | """
|
86 |
| - Ensure that the array can be compared with np.testing. |
87 |
| -
|
88 |
| - This involves transferring it from GPU to CPU memory, densifying it, etc. |
| 88 | + Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards. |
89 | 89 | """
|
90 |
| - if is_torch_namespace(xp): |
91 |
| - return np.asarray(array.cpu()) # type: ignore[attr-defined, return-value] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType] |
| 90 | + if is_cupy_namespace(xp): |
| 91 | + return xp.asnumpy(array) |
92 | 92 | if is_pydata_sparse_namespace(xp):
|
93 | 93 | return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
|
| 94 | + |
| 95 | + if is_torch_namespace(xp): |
| 96 | + array = to_device(array, "cpu") |
94 | 97 | if is_array_api_strict_namespace(xp):
|
95 |
| - # Note: we deliberately did not add a `.to_device` method in _typing.pyi |
96 |
| - # even if it is required by the standard as many backends don't support it |
97 |
| - return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] |
98 |
| - if is_cupy_namespace(xp): |
99 |
| - return xp.asnumpy(array) |
100 |
| - return array |
| 98 | + cpu: Device = xp.Device("CPU_DEVICE") |
| 99 | + array = to_device(array, cpu) |
| 100 | + if is_jax_namespace(xp): |
| 101 | + import jax |
| 102 | + |
| 103 | + # Note: only needed if the transfer guard is enabled |
| 104 | + cpu = cast(Device, jax.devices("cpu")[0]) |
| 105 | + array = to_device(array, cpu) |
| 106 | + |
| 107 | + return np.asarray(array) |
101 | 108 |
|
102 | 109 |
|
103 | 110 | def xp_assert_equal(
|
@@ -132,9 +139,9 @@ def xp_assert_equal(
|
132 | 139 | numpy.testing.assert_array_equal : Similar function for NumPy arrays.
|
133 | 140 | """
|
134 | 141 | xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
|
135 |
| - actual = _prepare_for_test(actual, xp) |
136 |
| - desired = _prepare_for_test(desired, xp) |
137 |
| - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) |
| 142 | + actual_np = as_numpy_array(actual, xp=xp) |
| 143 | + desired_np = as_numpy_array(desired, xp=xp) |
| 144 | + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) |
138 | 145 |
|
139 | 146 |
|
140 | 147 | def xp_assert_less(
|
@@ -167,9 +174,9 @@ def xp_assert_less(
|
167 | 174 | numpy.testing.assert_array_equal : Similar function for NumPy arrays.
|
168 | 175 | """
|
169 | 176 | xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
|
170 |
| - x = _prepare_for_test(x, xp) |
171 |
| - y = _prepare_for_test(y, xp) |
172 |
| - np.testing.assert_array_less(x, y, err_msg=err_msg) # type: ignore[call-overload] |
| 177 | + x_np = as_numpy_array(x, xp=xp) |
| 178 | + y_np = as_numpy_array(y, xp=xp) |
| 179 | + np.testing.assert_array_less(x_np, y_np, err_msg=err_msg) |
173 | 180 |
|
174 | 181 |
|
175 | 182 | def xp_assert_close(
|
@@ -216,23 +223,21 @@ def xp_assert_close(
|
216 | 223 | """
|
217 | 224 | xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
|
218 | 225 |
|
219 |
| - floating = xp.isdtype(actual.dtype, ("real floating", "complex floating")) |
220 |
| - if rtol is None and floating: |
221 |
| - # multiplier of 4 is used as for `np.float64` this puts the default `rtol` |
222 |
| - # roughly half way between sqrt(eps) and the default for |
223 |
| - # `numpy.testing.assert_allclose`, 1e-7 |
224 |
| - rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 |
225 |
| - elif rtol is None: |
226 |
| - rtol = 1e-7 |
227 |
| - |
228 |
| - actual = _prepare_for_test(actual, xp) |
229 |
| - desired = _prepare_for_test(desired, xp) |
230 |
| - |
231 |
| - # JAX/Dask arrays work directly with `np.testing` |
232 |
| - np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue] |
233 |
| - actual, # pyright: ignore[reportArgumentType] |
234 |
| - desired, # pyright: ignore[reportArgumentType] |
235 |
| - rtol=rtol, |
| 226 | + if rtol is None: |
| 227 | + if xp.isdtype(actual.dtype, ("real floating", "complex floating")): |
| 228 | + # multiplier of 4 is used as for `np.float64` this puts the default `rtol` |
| 229 | + # roughly half way between sqrt(eps) and the default for |
| 230 | + # `numpy.testing.assert_allclose`, 1e-7 |
| 231 | + rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 |
| 232 | + else: |
| 233 | + rtol = 1e-7 |
| 234 | + |
| 235 | + actual_np = as_numpy_array(actual, xp=xp) |
| 236 | + desired_np = as_numpy_array(desired, xp=xp) |
| 237 | + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] |
| 238 | + actual_np, |
| 239 | + desired_np, |
| 240 | + rtol=rtol, # pyright: ignore[reportArgumentType] |
236 | 241 | atol=atol,
|
237 | 242 | err_msg=err_msg,
|
238 | 243 | )
|
|
0 commit comments