Skip to content

Commit 500b1a0

Browse files
committed
Add versioning support to DLPack APIs
xref dmlc/dlpack#116
1 parent 3d91878 commit 500b1a0

File tree

2 files changed

+68
-21
lines changed

2 files changed

+68
-21
lines changed

src/array_api_stubs/_draft/array_object.py

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ def __complex__(self: array, /) -> complex:
278278
"""
279279

280280
def __dlpack__(
281-
self: array, /, *, stream: Optional[Union[int, Any]] = None
281+
self: array, /, *,
282+
max_version: Optional[tuple[int, int]] = None,
283+
stream: Optional[Union[int, Any]] = None
282284
) -> PyCapsule:
283285
"""
284286
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
@@ -287,46 +289,42 @@ def __dlpack__(
287289
----------
288290
self: array
289291
array instance.
292+
max_version: Optional[tuple[int, int]]
293+
The maximum DLPack version that the consumer (i.e., the caller of
294+
``__dlpack__``) supports, in the form ``(major, minor)``.
295+
This method may return that maximum version (recommended if it does
296+
support that), or a different version.
290297
stream: Optional[Union[int, Any]]
291298
for CUDA and ROCm, a Python integer representing a pointer to a stream, on devices that support streams. ``stream`` is provided by the consumer to the producer to instruct the producer to ensure that operations can safely be performed on the array (e.g., by inserting a dependency between streams via "wait for event"). The pointer must be a positive integer or ``-1``. If ``stream`` is ``-1``, the value may be used by the consumer to signal "producer must not perform any synchronization". The ownership of the stream stays with the consumer. On CPU and other device types without streams, only ``None`` is accepted.
292299
293300
For other device types which do have a stream, queue or similar synchronization mechanism, the most appropriate type to use for ``stream`` is not yet determined. E.g., for SYCL one may want to use an object containing an in-order ``cl::sycl::queue``. This is allowed when libraries agree on such a convention, and may be standardized in a future version of this API standard.
294301
302+
.. note::
303+
Support for a ``stream`` value other than ``None`` is optional and implementation-dependent.
295304
296-
.. note::
297-
Support for a ``stream`` value other than ``None`` is optional and implementation-dependent.
298-
299-
300-
Device-specific notes:
301-
302-
303-
.. admonition:: CUDA
304-
:class: note
305+
Device-specific values of ``stream`` for CUDA:
305306
306307
- ``None``: producer must assume the legacy default stream (default).
307308
- ``1``: the legacy default stream.
308309
- ``2``: the per-thread default stream.
309310
- ``> 2``: stream number represented as a Python integer.
310311
- ``0`` is disallowed due to its ambiguity: ``0`` could mean either ``None``, ``1``, or ``2``.
311312
312-
313-
.. admonition:: ROCm
314-
:class: note
313+
Device-specific values of ``stream`` for ROCm:
315314
316315
- ``None``: producer must assume the legacy default stream (default).
317316
- ``0``: the default stream.
318317
- ``> 2``: stream number represented as a Python integer.
319318
- Using ``1`` and ``2`` is not supported.
320319
320+
.. admonition:: Tip
321+
:class: important
321322
322-
.. admonition:: Tip
323-
:class: important
324-
325-
It is recommended that implementers explicitly handle streams. If
326-
they use the legacy default stream, specifying ``1`` (CUDA) or ``0``
327-
(ROCm) is preferred. ``None`` is a safe default for developers who do
328-
not want to think about stream handling at all, potentially at the
329-
cost of more synchronization than necessary.
323+
It is recommended that implementers explicitly handle streams. If
324+
they use the legacy default stream, specifying ``1`` (CUDA) or ``0``
325+
(ROCm) is preferred. ``None`` is a safe default for developers who do
326+
not want to think about stream handling at all, potentially at the
327+
cost of more synchronization than necessary.
330328
331329
Returns
332330
-------
@@ -343,9 +341,52 @@ def __dlpack__(
343341
344342
Notes
345343
-----
344+
Major DLPack versions represent ABI breaks, minor versions represent
345+
ABI-compatible additions (e.g., new enum values for new data types or
346+
device types).
347+
348+
The ``max_version`` keyword was introduced in v2023.12, and goes
349+
together with the ``DLManagedTensorVersioned`` struct added in DLPack
350+
1.0. This keyword may not be used by consumers for some time after
351+
introduction. It is recommended to use this logic in the implementation
352+
of ``__dlpack__``:
353+
354+
.. code:: python
355+
356+
if max_version is None:
357+
# Keep and use the DLPack 0.X implementation
358+
# Note: in >= 2 years from now (but ideally as late as
359+
# possible), it's okay to raise BufferError here
360+
else:
361+
# We get to produce `DLManagedTensorVersioned` now
362+
if max_version >= our_own_dlpack_version:
363+
# Consumer understands us, just return a Capsule with our max version
364+
elif max_version[0] == our_own_dlpack_version[0]:
365+
# major versions match, we should still be fine here -
366+
# return our own max version
367+
else:
368+
# if we're at a higher major version internally, did we
369+
# keep an implementation of the older major version around?
370+
# If so, use that. Else, just return our max
371+
# version and let the consumer deal with it.
372+
373+
And this logic for the producer (i.e., in ``from_dlpack``):
374+
375+
.. code:: python
376+
377+
try:
378+
x.__dlpack__(max_version=(1, 0))
379+
# if it succeeds, store info about capsule name being "dltensor_versioned",
380+
# and needing to set the capsule name to "used_dltensor_versioned"
381+
# when we're done
382+
except TypeError:
383+
x.__dlpack__()
346384
347385
.. versionchanged:: 2022.12
348386
Added BufferError.
387+
388+
.. versionchanged:: 2023.12
389+
Added the ``max_version`` keyword.
349390
"""
350391

351392
def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:

src/array_api_stubs/_draft/creation_functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ def from_dlpack(x: object, /) -> array:
212212
:class: note
213213
214214
The returned array may be either a copy or a view. See :ref:`data-interchange` for details.
215+
216+
Notes
217+
-----
218+
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
219+
order to handle DLPack versioning correctly.
220+
215221
"""
216222

217223

0 commit comments

Comments
 (0)