@@ -51,6 +51,8 @@ def __repr__(self):
51
51
52
52
CPU_DEVICE = _cpu_device ()
53
53
54
+ _default = object ()
55
+
54
56
class Array :
55
57
"""
56
58
n-d array object for the array API namespace.
@@ -525,10 +527,34 @@ def __complex__(self: Array, /) -> complex:
525
527
res = self ._array .__complex__ ()
526
528
return res
527
529
528
- def __dlpack__ (self : Array , / , * , stream : None = None ) -> PyCapsule :
530
+ def __dlpack__ (
531
+ self : Array ,
532
+ / ,
533
+ * ,
534
+ stream : Optional [Union [int , Any ]] = None ,
535
+ max_version : Optional [tuple [int , int ]] = _default ,
536
+ dl_device : Optional [tuple [IntEnum , int ]] = _default ,
537
+ copy : Optional [bool ] = _default ,
538
+ ) -> PyCapsule :
529
539
"""
530
540
Performs the operation __dlpack__.
531
541
"""
542
+ if get_array_api_strict_flags ()['api_version' ] < '2023.12' :
543
+ if max_version is not _default :
544
+ raise ValueError ("The max_version argument to __dlpack__ requires at least version 2023.12 of the array API" )
545
+ if dl_device is not _default :
546
+ raise ValueError ("The device argument to __dlpack__ requires at least version 2023.12 of the array API" )
547
+ if copy is not _default :
548
+ raise ValueError ("The copy argument to __dlpack__ requires at least version 2023.12 of the array API" )
549
+
550
+ # Going to wait for upstream numpy support
551
+ if max_version not in [_default , None ]:
552
+ raise NotImplementedError ("The max_version argument to __dlpack__ is not yet implemented" )
553
+ if dl_device not in [_default , None ]:
554
+ raise NotImplementedError ("The device argument to __dlpack__ is not yet implemented" )
555
+ if copy not in [_default , None ]:
556
+ raise NotImplementedError ("The copy argument to __dlpack__ is not yet implemented" )
557
+
532
558
return self ._array .__dlpack__ (stream = stream )
533
559
534
560
def __dlpack_device__ (self : Array , / ) -> Tuple [IntEnum , int ]:
0 commit comments