Skip to content

Support copy and device keywords in from_dlpack #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 14, 2024
22 changes: 19 additions & 3 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __dlpack__(
*,
stream: Optional[Union[int, Any]] = None,
max_version: Optional[tuple[int, int]] = None,
dl_device: Optional[Tuple[Enum, int]] = None,
copy: Optional[bool] = None
) -> PyCapsule:
"""
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
Expand Down Expand Up @@ -333,12 +335,24 @@ def __dlpack__(
not want to think about stream handling at all, potentially at the
cost of more synchronizations than necessary.
max_version: Optional[tuple[int, int]]
The maximum DLPack version that the *consumer* (i.e., the caller of
the maximum DLPack version that the *consumer* (i.e., the caller of
``__dlpack__``) supports, in the form of a 2-tuple ``(major, minor)``.
This method may return a capsule of version ``max_version`` (recommended
if it does support that), or of a different version.
This means the consumer must verify the version even when
`max_version` is passed.
dl_device: Optional[Tuple[Enum, int]]
the DLPack device type. Default is ``None``, meaning the exported capsule
should be on the same device as ``self`` is. When specified, the format
must follow that of the return value of :meth:`array.__dlpack_device__`.
If the device type cannot be handled by the producer, this function must
raise `BufferError`.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the
function must always copy (paerformed by the producer), potentially allowing
data movement across the library (and/or device) boundary. If ``False``,
the function must never copy. If ``None``, the function must reuse existing
memory buffer if possible and copy otherwise. Default: ``None``.

Returns
-------
Expand Down Expand Up @@ -394,7 +408,7 @@ def __dlpack__(
# here to tell users that the consumer's max_version is too
# old to allow the data exchange to happen.

And this logic for the consumer in ``from_dlpack``:
And this logic for the consumer in :func:`~array_api.from_dlpack`:

.. code:: python

Expand All @@ -409,7 +423,7 @@ def __dlpack__(
Added BufferError.

.. versionchanged:: 2023.12
Added the ``max_version`` keyword.
Added the ``max_version``, ``dl_device``, and ``copy`` keywords.
"""

def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
Expand All @@ -436,6 +450,8 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
METAL = 8
VPI = 9
ROCM = 10
CUDA_MANAGED = 13
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
Expand Down
33 changes: 28 additions & 5 deletions src/array_api_stubs/_draft/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


from ._types import (
Any,
List,
NestedSequence,
Optional,
Expand Down Expand Up @@ -214,19 +215,38 @@ def eye(
"""


def from_dlpack(x: object, /) -> array:
def from_dlpack(
x: object, /, *,
device: Optional[device] = None,
copy: Optional[bool] = None
) -> Union[array, Any]:
"""
Returns a new array containing the data from another (array) object with a ``__dlpack__`` method.

Parameters
----------
x: object
input (array) object.
device: Optional[device]
device on which to place the created array. If ``device`` is ``None`` and ``x`` supports DLPack, the output array device must be inferred from ``x``. Default: ``None``.

The v2023.12 standard only mandates that a compliant library must offer a way for ``from_dlpack`` to create an array on CPU (using
the library-chosen way to represent the CPU device - ``kDLCPU`` in DLPack - e.g. a ``"CPU"`` string or a ``Device("CPU")`` object).
If the compliant library does not support the CPU device and needs to outsource to another (compliant) array library, it may do so
with a clear user documentation and/or run-time warning. If a copy must be made to enable this, and ``copy`` is set to ``False``,
the function must raise ``ValueError``.

Other kinds of devices will be considered for standardization in a future version.
copy: Optional[bool]
boolean indicating whether or not to copy the input. If ``True``, the function must always copy. If ``False``, the function must never copy and must raise a ``BufferError`` in case a copy would be necessary (e.g. the producer disallows views). If ``None``, the function must reuse existing memory buffer if possible and copy otherwise. Default: ``None``.

If a copy is needed, the stream over which the copy is performed must be taken from the consumer, following the DLPack protocol (see :meth:`array.__dlpack__`).

Returns
-------
out: array
an array containing the data in `x`.
out: Union[array, Any]
an array containing the data in ``x``. In the case that the compliant library does not support the given ``device`` out of box
and must oursource to another (compliant) library, the output will be that library's compliant array object.

.. admonition:: Note
:class: note
Expand All @@ -238,9 +258,9 @@ def from_dlpack(x: object, /) -> array:
BufferError
The ``__dlpack__`` and ``__dlpack_device__`` methods on the input array
may raise ``BufferError`` when the data cannot be exported as DLPack
(e.g., incompatible dtype or strides). It may also raise other errors
(e.g., incompatible dtype, strides, or device). It may also raise other errors
when export fails for other reasons (e.g., not enough memory available
to materialize the data). ``from_dlpack`` must propagate such
to materialize the data, a copy must made, etc). ``from_dlpack`` must propagate such
exceptions.
AttributeError
If the ``__dlpack__`` and ``__dlpack_device__`` methods are not present
Expand All @@ -251,6 +271,9 @@ def from_dlpack(x: object, /) -> array:
-----
See :meth:`array.__dlpack__` for implementation suggestions for `from_dlpack` in
order to handle DLPack versioning correctly.

.. versionchanged:: 2023.12
Added device and copy support.
"""


Expand Down