Skip to content

Commit dc4684b

Browse files
committed
Update the signature of __dlpack__ for 2023.12
The new arguments are not actually supported yet, and probably won't be until upstream NumPy does.
1 parent 1ac5288 commit dc4684b

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

array_api_strict/_array_object.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __repr__(self):
5151

5252
CPU_DEVICE = _cpu_device()
5353

54+
_default = object()
55+
5456
class Array:
5557
"""
5658
n-d array object for the array API namespace.
@@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
525527
res = self._array.__complex__()
526528
return res
527529

528-
def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
530+
def __dlpack__(
531+
self: Array,
532+
/,
533+
*,
534+
stream: Optional[Union[int, Any]] = None,
535+
max_version: Optional[tuple[int, int]] = _default,
536+
dl_device: Optional[tuple[IntEnum, int]] = _default,
537+
copy: Optional[bool] = _default,
538+
) -> PyCapsule:
529539
"""
530540
Performs the operation __dlpack__.
531541
"""
542+
if get_array_api_strict_flags()['api_version'] < '2023.12':
543+
if max_version is not _default:
544+
raise ValueError("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API")
545+
if dl_device is not _default:
546+
raise ValueError("The device argument to __dlpack__ requires at least version 2023.12 of the array API")
547+
if copy is not _default:
548+
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
549+
550+
# Going to wait for upstream numpy support
551+
if max_version not in [_default, None]:
552+
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
553+
if dl_device not in [_default, None]:
554+
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
555+
if copy not in [_default, None]:
556+
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
557+
532558
return self._array.__dlpack__(stream=stream)
533559

534560
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:

array_api_strict/_creation_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def from_dlpack(
202202
if copy is not _default:
203203
raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API")
204204

205+
# Going to wait for upstream numpy support
205206
if device is not _default:
206207
_check_device(device)
207208
if copy not in [_default, None]:

0 commit comments

Comments
 (0)