Skip to content

Commit 1ac5288

Browse files
committed
Add 2023.12 device and copy keywords to from_dlpack
The copy keyword just raises NotImplementedError for now.
1 parent 3fde5dd commit 1ac5288

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

array_api_strict/_creation_functions.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
SupportsBufferProtocol,
1313
)
1414
from ._dtypes import _DType, _all_dtypes
15+
from ._flags import get_array_api_strict_flags
1516

1617
import numpy as np
1718

@@ -174,19 +175,38 @@ def eye(
174175
175176
See its docstring for more information.
176177
"""
177-
from ._array_object import Array, CPU_DEVICE
178+
from ._array_object import Array
178179

179180
_check_valid_dtype(dtype)
180-
if device not in [CPU_DEVICE, None]:
181-
raise ValueError(f"Unsupported device {device!r}")
181+
_check_device(device)
182+
182183
if dtype is not None:
183184
dtype = dtype._np_dtype
184185
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
185186

186187

187-
def from_dlpack(x: object, /) -> Array:
188+
_default = object()
189+
190+
def from_dlpack(
191+
x: object,
192+
/,
193+
*,
194+
device: Optional[Device] = _default,
195+
copy: Optional[bool] = _default,
196+
) -> Array:
188197
from ._array_object import Array
189198

199+
if get_array_api_strict_flags()['api_version'] < '2023.12':
200+
if device is not _default:
201+
raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API")
202+
if copy is not _default:
203+
raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API")
204+
205+
if device is not _default:
206+
_check_device(device)
207+
if copy not in [_default, None]:
208+
raise NotImplementedError("The copy argument to from_dlpack is not yet implemented")
209+
190210
return Array._new(np.from_dlpack(x))
191211

192212

array_api_strict/_data_type_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def astype(
3636
if get_array_api_strict_flags()['api_version'] >= '2023.12':
3737
_check_device(device)
3838
else:
39-
raise TypeError("The device argument to astype requires the 2023.12 version of the array API")
39+
raise TypeError("The device argument to astype requires at least version 2023.12 of the array API")
4040

4141
if not copy and dtype == x.dtype:
4242
return x

0 commit comments

Comments
 (0)