Skip to content

Commit cb3c2d6

Browse files
authored
Rework prepare_for_test (data-apis#2)
1 parent a6d2d8c commit cb3c2d6

File tree

2 files changed

+51
-39
lines changed

2 files changed

+51
-39
lines changed

src/array_api_extra/_lib/_testing.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import math
99
from types import ModuleType
10-
from typing import cast
10+
from typing import Any, cast
1111

1212
import numpy as np
1313
import pytest
@@ -17,13 +17,15 @@
1717
is_array_api_strict_namespace,
1818
is_cupy_namespace,
1919
is_dask_namespace,
20+
is_jax_namespace,
2021
is_numpy_namespace,
2122
is_pydata_sparse_namespace,
2223
is_torch_namespace,
24+
to_device,
2325
)
24-
from ._utils._typing import Array
26+
from ._utils._typing import Array, Device
2527

26-
__all__ = ["xp_assert_close", "xp_assert_equal"]
28+
__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"]
2729

2830

2931
def _check_ns_shape_dtype(
@@ -81,23 +83,28 @@ def _check_ns_shape_dtype(
8183
return desired_xp
8284

8385

84-
def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
86+
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
8587
"""
86-
Ensure that the array can be compared with np.testing.
87-
88-
This involves transferring it from GPU to CPU memory, densifying it, etc.
88+
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
8989
"""
90-
if is_torch_namespace(xp):
91-
return np.asarray(array.cpu()) # type: ignore[attr-defined, return-value] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType]
90+
if is_cupy_namespace(xp):
91+
return xp.asnumpy(array)
9292
if is_pydata_sparse_namespace(xp):
9393
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
94+
95+
if is_torch_namespace(xp):
96+
array = to_device(array, "cpu")
9497
if is_array_api_strict_namespace(xp):
95-
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
96-
# even if it is required by the standard as many backends don't support it
97-
return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
98-
if is_cupy_namespace(xp):
99-
return xp.asnumpy(array)
100-
return array
98+
cpu: Device = xp.Device("CPU_DEVICE")
99+
array = to_device(array, cpu)
100+
if is_jax_namespace(xp):
101+
import jax
102+
103+
# Note: only needed if the transfer guard is enabled
104+
cpu = cast(Device, jax.devices("cpu")[0])
105+
array = to_device(array, cpu)
106+
107+
return np.asarray(array)
101108

102109

103110
def xp_assert_equal(
@@ -132,9 +139,9 @@ def xp_assert_equal(
132139
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
133140
"""
134141
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
135-
actual = _prepare_for_test(actual, xp)
136-
desired = _prepare_for_test(desired, xp)
137-
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
142+
actual_np = as_numpy_array(actual, xp=xp)
143+
desired_np = as_numpy_array(desired, xp=xp)
144+
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
138145

139146

140147
def xp_assert_less(
@@ -167,9 +174,9 @@ def xp_assert_less(
167174
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
168175
"""
169176
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
170-
x = _prepare_for_test(x, xp)
171-
y = _prepare_for_test(y, xp)
172-
np.testing.assert_array_less(x, y, err_msg=err_msg) # type: ignore[call-overload]
177+
x_np = as_numpy_array(x, xp=xp)
178+
y_np = as_numpy_array(y, xp=xp)
179+
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
173180

174181

175182
def xp_assert_close(
@@ -216,23 +223,21 @@ def xp_assert_close(
216223
"""
217224
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
218225

219-
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
220-
if rtol is None and floating:
221-
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
222-
# roughly half way between sqrt(eps) and the default for
223-
# `numpy.testing.assert_allclose`, 1e-7
224-
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
225-
elif rtol is None:
226-
rtol = 1e-7
227-
228-
actual = _prepare_for_test(actual, xp)
229-
desired = _prepare_for_test(desired, xp)
230-
231-
# JAX/Dask arrays work directly with `np.testing`
232-
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
233-
actual, # pyright: ignore[reportArgumentType]
234-
desired, # pyright: ignore[reportArgumentType]
235-
rtol=rtol,
226+
if rtol is None:
227+
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
228+
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
229+
# roughly half way between sqrt(eps) and the default for
230+
# `numpy.testing.assert_allclose`, 1e-7
231+
rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4
232+
else:
233+
rtol = 1e-7
234+
235+
actual_np = as_numpy_array(actual, xp=xp)
236+
desired_np = as_numpy_array(desired, xp=xp)
237+
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
238+
actual_np,
239+
desired_np,
240+
rtol=rtol, # pyright: ignore[reportArgumentType]
236241
atol=atol,
237242
err_msg=err_msg,
238243
)

tests/test_testing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from array_api_extra._lib._backends import Backend
1010
from array_api_extra._lib._testing import (
11+
as_numpy_array,
1112
xp_assert_close,
1213
xp_assert_equal,
1314
xp_assert_less,
@@ -17,7 +18,7 @@
1718
is_dask_namespace,
1819
is_jax_namespace,
1920
)
20-
from array_api_extra._lib._utils._typing import Array
21+
from array_api_extra._lib._utils._typing import Array, Device
2122
from array_api_extra.testing import lazy_xp_function
2223

2324
# mypy: disable-error-code=decorated-any
@@ -38,6 +39,12 @@
3839
)
3940

4041

42+
def test_as_numpy_array(xp: ModuleType, device: Device):
43+
x = xp.asarray([1, 2, 3], device=device)
44+
y = as_numpy_array(x, xp=xp)
45+
assert isinstance(y, np.ndarray)
46+
47+
4148
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
4249
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
4350
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]

0 commit comments

Comments
 (0)