Skip to content

Commit 3fde5dd

Browse files
committed
Factor out device checks into a helper function
1 parent 6b43194 commit 3fde5dd

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

array_api_strict/_creation_functions.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ def _supports_buffer_protocol(obj):
2828
return False
2929
return True
3030

31+
def _check_device(device):
32+
# _array_object imports in this file are inside the functions to avoid
33+
# circular imports
34+
from ._array_object import CPU_DEVICE
35+
36+
if device not in [CPU_DEVICE, None]:
37+
raise ValueError(f"Unsupported device {device!r}")
38+
3139
def asarray(
3240
obj: Union[
3341
Array,
@@ -48,16 +56,13 @@ def asarray(
4856
4957
See its docstring for more information.
5058
"""
51-
# _array_object imports in this file are inside the functions to avoid
52-
# circular imports
53-
from ._array_object import Array, CPU_DEVICE
59+
from ._array_object import Array
5460

5561
_check_valid_dtype(dtype)
5662
_np_dtype = None
5763
if dtype is not None:
5864
_np_dtype = dtype._np_dtype
59-
if device not in [CPU_DEVICE, None]:
60-
raise ValueError(f"Unsupported device {device!r}")
65+
_check_device(device)
6166

6267
if np.__version__[0] < '2':
6368
if copy is False:
@@ -106,11 +111,11 @@ def arange(
106111
107112
See its docstring for more information.
108113
"""
109-
from ._array_object import Array, CPU_DEVICE
114+
from ._array_object import Array
110115

111116
_check_valid_dtype(dtype)
112-
if device not in [CPU_DEVICE, None]:
113-
raise ValueError(f"Unsupported device {device!r}")
117+
_check_device(device)
118+
114119
if dtype is not None:
115120
dtype = dtype._np_dtype
116121
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
@@ -127,11 +132,11 @@ def empty(
127132
128133
See its docstring for more information.
129134
"""
130-
from ._array_object import Array, CPU_DEVICE
135+
from ._array_object import Array
131136

132137
_check_valid_dtype(dtype)
133-
if device not in [CPU_DEVICE, None]:
134-
raise ValueError(f"Unsupported device {device!r}")
138+
_check_device(device)
139+
135140
if dtype is not None:
136141
dtype = dtype._np_dtype
137142
return Array._new(np.empty(shape, dtype=dtype))
@@ -145,11 +150,11 @@ def empty_like(
145150
146151
See its docstring for more information.
147152
"""
148-
from ._array_object import Array, CPU_DEVICE
153+
from ._array_object import Array
149154

150155
_check_valid_dtype(dtype)
151-
if device not in [CPU_DEVICE, None]:
152-
raise ValueError(f"Unsupported device {device!r}")
156+
_check_device(device)
157+
153158
if dtype is not None:
154159
dtype = dtype._np_dtype
155160
return Array._new(np.empty_like(x._array, dtype=dtype))
@@ -197,11 +202,11 @@ def full(
197202
198203
See its docstring for more information.
199204
"""
200-
from ._array_object import Array, CPU_DEVICE
205+
from ._array_object import Array
201206

202207
_check_valid_dtype(dtype)
203-
if device not in [CPU_DEVICE, None]:
204-
raise ValueError(f"Unsupported device {device!r}")
208+
_check_device(device)
209+
205210
if isinstance(fill_value, Array) and fill_value.ndim == 0:
206211
fill_value = fill_value._array
207212
if dtype is not None:
@@ -227,11 +232,11 @@ def full_like(
227232
228233
See its docstring for more information.
229234
"""
230-
from ._array_object import Array, CPU_DEVICE
235+
from ._array_object import Array
231236

232237
_check_valid_dtype(dtype)
233-
if device not in [CPU_DEVICE, None]:
234-
raise ValueError(f"Unsupported device {device!r}")
238+
_check_device(device)
239+
235240
if dtype is not None:
236241
dtype = dtype._np_dtype
237242
res = np.full_like(x._array, fill_value, dtype=dtype)
@@ -257,11 +262,11 @@ def linspace(
257262
258263
See its docstring for more information.
259264
"""
260-
from ._array_object import Array, CPU_DEVICE
265+
from ._array_object import Array
261266

262267
_check_valid_dtype(dtype)
263-
if device not in [CPU_DEVICE, None]:
264-
raise ValueError(f"Unsupported device {device!r}")
268+
_check_device(device)
269+
265270
if dtype is not None:
266271
dtype = dtype._np_dtype
267272
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
@@ -298,11 +303,11 @@ def ones(
298303
299304
See its docstring for more information.
300305
"""
301-
from ._array_object import Array, CPU_DEVICE
306+
from ._array_object import Array
302307

303308
_check_valid_dtype(dtype)
304-
if device not in [CPU_DEVICE, None]:
305-
raise ValueError(f"Unsupported device {device!r}")
309+
_check_device(device)
310+
306311
if dtype is not None:
307312
dtype = dtype._np_dtype
308313
return Array._new(np.ones(shape, dtype=dtype))
@@ -316,11 +321,11 @@ def ones_like(
316321
317322
See its docstring for more information.
318323
"""
319-
from ._array_object import Array, CPU_DEVICE
324+
from ._array_object import Array
320325

321326
_check_valid_dtype(dtype)
322-
if device not in [CPU_DEVICE, None]:
323-
raise ValueError(f"Unsupported device {device!r}")
327+
_check_device(device)
328+
324329
if dtype is not None:
325330
dtype = dtype._np_dtype
326331
return Array._new(np.ones_like(x._array, dtype=dtype))
@@ -365,11 +370,11 @@ def zeros(
365370
366371
See its docstring for more information.
367372
"""
368-
from ._array_object import Array, CPU_DEVICE
373+
from ._array_object import Array
369374

370375
_check_valid_dtype(dtype)
371-
if device not in [CPU_DEVICE, None]:
372-
raise ValueError(f"Unsupported device {device!r}")
376+
_check_device(device)
377+
373378
if dtype is not None:
374379
dtype = dtype._np_dtype
375380
return Array._new(np.zeros(shape, dtype=dtype))
@@ -383,11 +388,11 @@ def zeros_like(
383388
384389
See its docstring for more information.
385390
"""
386-
from ._array_object import Array, CPU_DEVICE
391+
from ._array_object import Array
387392

388393
_check_valid_dtype(dtype)
389-
if device not in [CPU_DEVICE, None]:
390-
raise ValueError(f"Unsupported device {device!r}")
394+
_check_device(device)
395+
391396
if dtype is not None:
392397
dtype = dtype._np_dtype
393398
return Array._new(np.zeros_like(x._array, dtype=dtype))

array_api_strict/_data_type_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array, CPU_DEVICE
3+
from ._array_object import Array
4+
from ._creation_functions import _check_device
45
from ._dtypes import (
56
_DType,
67
_all_dtypes,
@@ -33,8 +34,7 @@ def astype(
3334
) -> Array:
3435
if device is not _default:
3536
if get_array_api_strict_flags()['api_version'] >= '2023.12':
36-
if device not in [CPU_DEVICE, None]:
37-
raise ValueError(f"Unsupported device {device!r}")
37+
_check_device(device)
3838
else:
3939
raise TypeError("The device argument to astype requires the 2023.12 version of the array API")
4040

0 commit comments

Comments
 (0)