Skip to content

Commit 3bc8199

Browse files
committed
More multi-device support
1 parent 03e1ae7 commit 3bc8199

File tree

5 files changed

+50
-19
lines changed

5 files changed

+50
-19
lines changed

array_api_strict/_array_object.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __hash__(self):
6262

6363

6464
CPU_DEVICE = Device()
65+
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))
6566

6667
_default = object()
6768

array_api_strict/_creation_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ def _supports_buffer_protocol(obj):
3232
def _check_device(device):
3333
# _array_object imports in this file are inside the functions to avoid
3434
# circular imports
35-
from ._array_object import Device
35+
from ._array_object import Device, ALL_DEVICES
3636

3737
if device is not None and not isinstance(device, Device):
3838
raise ValueError(f"Unsupported device {device!r}")
3939

40+
if device not in ALL_DEVICES:
41+
raise ValueError(f"Unsupported device {device!r}")
42+
4043
def asarray(
4144
obj: Union[
4245
Array,

array_api_strict/_data_type_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ def astype(
3737
_check_device(device)
3838
else:
3939
raise TypeError("The device argument to astype requires at least version 2023.12 of the array API")
40+
else:
41+
device = x.device
4042

4143
if not copy and dtype == x.dtype:
4244
return x
43-
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy))
45+
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device)
4446

4547

4648
def broadcast_arrays(*arrays: Array) -> List[Array]:

array_api_strict/_fft.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
float32,
1515
complex64,
1616
)
17-
from ._array_object import Array, CPU_DEVICE
17+
from ._array_object import Array, ALL_DEVICES
1818
from ._data_type_functions import astype
1919
from ._flags import requires_extension
2020

@@ -36,7 +36,7 @@ def fft(
3636
"""
3737
if x.dtype not in _complex_floating_dtypes:
3838
raise TypeError("Only complex floating-point dtypes are allowed in fft")
39-
res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm))
39+
res = Array._new(np.fft.fft(x._array, n=n, axis=axis, norm=norm), device=x.device)
4040
# Note: np.fft functions improperly upcast float32 and complex64 to
4141
# complex128
4242
if x.dtype == complex64:
@@ -59,7 +59,7 @@ def ifft(
5959
"""
6060
if x.dtype not in _complex_floating_dtypes:
6161
raise TypeError("Only complex floating-point dtypes are allowed in ifft")
62-
res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm))
62+
res = Array._new(np.fft.ifft(x._array, n=n, axis=axis, norm=norm), device=x.device)
6363
# Note: np.fft functions improperly upcast float32 and complex64 to
6464
# complex128
6565
if x.dtype == complex64:
@@ -82,7 +82,7 @@ def fftn(
8282
"""
8383
if x.dtype not in _complex_floating_dtypes:
8484
raise TypeError("Only complex floating-point dtypes are allowed in fftn")
85-
res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm))
85+
res = Array._new(np.fft.fftn(x._array, s=s, axes=axes, norm=norm), device=x.device)
8686
# Note: np.fft functions improperly upcast float32 and complex64 to
8787
# complex128
8888
if x.dtype == complex64:
@@ -105,7 +105,7 @@ def ifftn(
105105
"""
106106
if x.dtype not in _complex_floating_dtypes:
107107
raise TypeError("Only complex floating-point dtypes are allowed in ifftn")
108-
res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm))
108+
res = Array._new(np.fft.ifftn(x._array, s=s, axes=axes, norm=norm), device=x.device)
109109
# Note: np.fft functions improperly upcast float32 and complex64 to
110110
# complex128
111111
if x.dtype == complex64:
@@ -128,7 +128,7 @@ def rfft(
128128
"""
129129
if x.dtype not in _real_floating_dtypes:
130130
raise TypeError("Only real floating-point dtypes are allowed in rfft")
131-
res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm))
131+
res = Array._new(np.fft.rfft(x._array, n=n, axis=axis, norm=norm), device=x.device)
132132
# Note: np.fft functions improperly upcast float32 and complex64 to
133133
# complex128
134134
if x.dtype == float32:
@@ -151,7 +151,7 @@ def irfft(
151151
"""
152152
if x.dtype not in _complex_floating_dtypes:
153153
raise TypeError("Only complex floating-point dtypes are allowed in irfft")
154-
res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm))
154+
res = Array._new(np.fft.irfft(x._array, n=n, axis=axis, norm=norm), device=x.device)
155155
# Note: np.fft functions improperly upcast float32 and complex64 to
156156
# complex128
157157
if x.dtype == complex64:
@@ -174,7 +174,7 @@ def rfftn(
174174
"""
175175
if x.dtype not in _real_floating_dtypes:
176176
raise TypeError("Only real floating-point dtypes are allowed in rfftn")
177-
res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm))
177+
res = Array._new(np.fft.rfftn(x._array, s=s, axes=axes, norm=norm), device=x.device)
178178
# Note: np.fft functions improperly upcast float32 and complex64 to
179179
# complex128
180180
if x.dtype == float32:
@@ -197,7 +197,7 @@ def irfftn(
197197
"""
198198
if x.dtype not in _complex_floating_dtypes:
199199
raise TypeError("Only complex floating-point dtypes are allowed in irfftn")
200-
res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm))
200+
res = Array._new(np.fft.irfftn(x._array, s=s, axes=axes, norm=norm), device=x.device)
201201
# Note: np.fft functions improperly upcast float32 and complex64 to
202202
# complex128
203203
if x.dtype == complex64:
@@ -220,7 +220,7 @@ def hfft(
220220
"""
221221
if x.dtype not in _complex_floating_dtypes:
222222
raise TypeError("Only complex floating-point dtypes are allowed in hfft")
223-
res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm))
223+
res = Array._new(np.fft.hfft(x._array, n=n, axis=axis, norm=norm), device=x.device)
224224
# Note: np.fft functions improperly upcast float32 and complex64 to
225225
# complex128
226226
if x.dtype == complex64:
@@ -243,7 +243,7 @@ def ihfft(
243243
"""
244244
if x.dtype not in _real_floating_dtypes:
245245
raise TypeError("Only real floating-point dtypes are allowed in ihfft")
246-
res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm))
246+
res = Array._new(np.fft.ihfft(x._array, n=n, axis=axis, norm=norm), device=x.device)
247247
# Note: np.fft functions improperly upcast float32 and complex64 to
248248
# complex128
249249
if x.dtype == float32:
@@ -257,9 +257,9 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar
257257
258258
See its docstring for more information.
259259
"""
260-
if device not in [CPU_DEVICE, None]:
260+
if device not in ALL_DEVICES:
261261
raise ValueError(f"Unsupported device {device!r}")
262-
return Array._new(np.fft.fftfreq(n, d=d))
262+
return Array._new(np.fft.fftfreq(n, d=d), device=device)
263263

264264
@requires_extension('fft')
265265
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
@@ -268,9 +268,9 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A
268268
269269
See its docstring for more information.
270270
"""
271-
if device not in [CPU_DEVICE, None]:
271+
if device not in ALL_DEVICES:
272272
raise ValueError(f"Unsupported device {device!r}")
273-
return Array._new(np.fft.rfftfreq(n, d=d))
273+
return Array._new(np.fft.rfftfreq(n, d=d), device=device)
274274

275275
@requires_extension('fft')
276276
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
@@ -281,7 +281,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
281281
"""
282282
if x.dtype not in _floating_dtypes:
283283
raise TypeError("Only floating-point dtypes are allowed in fftshift")
284-
return Array._new(np.fft.fftshift(x._array, axes=axes))
284+
return Array._new(np.fft.fftshift(x._array, axes=axes), device=x.device)
285285

286286
@requires_extension('fft')
287287
def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
@@ -292,7 +292,7 @@ def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
292292
"""
293293
if x.dtype not in _floating_dtypes:
294294
raise TypeError("Only floating-point dtypes are allowed in ifftshift")
295-
return Array._new(np.fft.ifftshift(x._array, axes=axes))
295+
return Array._new(np.fft.ifftshift(x._array, axes=axes), device=x.device)
296296

297297
__all__ = [
298298
"fft",
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
import array_api_strict
4+
5+
6+
@pytest.mark.parametrize("func_name", ("fft", "ifft", "fftn", "ifftn", "irfft",
7+
"irfftn", "hfft", "fftshift", "ifftshift"))
8+
def test_fft_device_support_complex(func_name):
9+
func = getattr(array_api_strict.fft, func_name)
10+
x = array_api_strict.asarray([1, 2.],
11+
dtype=array_api_strict.complex64,
12+
device=array_api_strict.Device("device1"))
13+
y = func(x)
14+
15+
assert x.device == y.device
16+
17+
18+
@pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft"))
19+
def test_fft_device_support_real(func_name):
20+
func = getattr(array_api_strict.fft, func_name)
21+
x = array_api_strict.asarray([1, 2.],
22+
device=array_api_strict.Device("device1"))
23+
y = func(x)
24+
25+
assert x.device == y.device

0 commit comments

Comments
 (0)