Commit b55b8ba: "add docs" (1 parent: 8ce0aa6)

File tree: 4 files changed (+112, -1 lines changed)

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 80 additions & 1 deletion
@@ -20,7 +20,59 @@ from cuda.core.experimental._utils import handle_return
 
 @cython.dataclasses.dataclass
 cdef class StridedMemoryView:
-
+    """A dataclass holding metadata of a strided dense array/tensor.
+
+    A :obj:`StridedMemoryView` instance can be created in two ways:
+
+      1. Using the :obj:`viewable` decorator (recommended)
+      2. Explicit construction, see below
+
+    This object supports both DLPack (up to v1.0) and CUDA Array Interface
+    (CAI) v3. When wrapping an arbitrary object it will try the DLPack
+    protocol first, then the CAI protocol. A :obj:`BufferError` is raised
+    if neither is supported.
+
+    Since either way would take a consumer stream, for DLPack it is passed
+    to ``obj.__dlpack__()`` as-is (except for :obj:`None`, see below); for
+    CAI, a stream order will be established between the consumer stream and
+    the producer stream (from ``obj.__cuda_array_interface__()["stream"]``),
+    as if ``cudaStreamWaitEvent`` is called by this method.
+
+    To opt out of the stream ordering operation in either DLPack or CAI,
+    please pass ``stream_ptr=-1``. Note that this deviates (on purpose)
+    from the semantics of ``obj.__dlpack__(stream=None, ...)``, since
+    ``cuda.core`` does not encourage using the (legacy) default/null stream,
+    but it is consistent with the CAI's semantics. For DLPack, ``stream=-1``
+    will be internally passed to ``obj.__dlpack__()`` instead.
+
+    Attributes
+    ----------
+    ptr : int
+        Pointer to the tensor buffer (as a Python `int`).
+    shape : tuple
+        Shape of the tensor.
+    strides : tuple
+        Strides of the tensor (in **counts**, not bytes).
+    dtype : numpy.dtype
+        Data type of the tensor.
+    device_id : int
+        The device ID where the tensor is located. It is 0 for CPU tensors.
+    device_accessible : bool
+        Whether the tensor data can be accessed on the GPU.
+    readonly : bool
+        Whether the tensor data can be modified in place.
+    exporting_obj : Any
+        A reference to the original tensor object that is being viewed.
+
+    Parameters
+    ----------
+    obj : Any
+        Any object that supports either DLPack (up to v1.0) or CUDA Array
+        Interface (v3).
+    stream_ptr : int
+        The pointer address (as a Python `int`) of the **consumer** stream.
+        Stream ordering will be properly established unless ``-1`` is passed.
+    """
     # TODO: switch to use Cython's cdef typing?
     ptr: int = None
     shape: tuple = None
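
For illustration (not part of this commit), explicit construction, the second way listed in the docstring above, might look like the minimal sketch below. It assumes NumPy is installed and that StridedMemoryView is importable from cuda.core.experimental.utils, the module registered by the api.rst change further down; the CPU test in this commit uses the same stream_ptr=-1 pattern.

    import numpy as np
    from cuda.core.experimental.utils import StridedMemoryView

    # Any DLPack- or CUDA-Array-Interface-capable object works; a host NumPy
    # array exercises the CPU path.
    arr = np.arange(12, dtype=np.float32).reshape(3, 4)

    # stream_ptr=-1 opts out of stream ordering ("the consumer does not care").
    view = StridedMemoryView(arr, stream_ptr=-1)

    print(view.ptr)        # pointer to the buffer, as a Python int
    print(view.shape)      # (3, 4)
    print(view.dtype)      # float32
    print(view.device_id)  # 0 for CPU tensors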
@@ -285,6 +337,33 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
 
 
 def viewable(tuple arg_indices):
+    """Decorator to create proxy objects to :obj:`StridedMemoryView` for
+    the specified positional arguments.
+
+    Inside the decorated function, the specified arguments become instances
+    of an (undocumented) proxy type, regardless of their original source. A
+    :obj:`StridedMemoryView` instance can be obtained by passing the
+    (consumer) stream pointer (as a Python `int`) to the proxies'
+    ``view()`` method. For example:
+
+    .. code-block:: python
+
+        @viewable((1,))
+        def my_func(arg0, arg1, arg2, stream: Stream):
+            # arg1 can be any object supporting DLPack or CUDA Array Interface
+            view = arg1.view(stream.handle)
+            assert isinstance(view, StridedMemoryView)
+            ...
+
+    This allows array/tensor attributes to be accessed inside the function
+    implementation, while keeping the function body array-library-agnostic
+    (if desired).
+
+    Parameters
+    ----------
+    arg_indices : tuple
+        The indices of the target positional arguments.
+    """
     def wrapped_func_with_indices(func):
         @functools.wraps(func)
         def wrapped_func(*args, **kwargs):
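
End to end, a function decorated with viewable might be driven as in the sketch below (again illustrative, not taken from the commit). It assumes CuPy as the producer library, uses Device, create_stream(), and the stream's handle attribute the same way the GPU tests in this commit do, and the name describe_array is made up.

    import cupy as cp
    from cuda.core.experimental import Device
    from cuda.core.experimental.utils import StridedMemoryView, viewable

    @viewable((0,))
    def describe_array(arr, stream):
        # arr arrives as a proxy; materialize the view on the consumer stream
        view = arr.view(stream.handle)
        assert isinstance(view, StridedMemoryView)
        return view.shape, view.dtype, view.device_id

    dev = Device()
    dev.set_current()
    s = dev.create_stream()  # the consumer stream

    a = cp.arange(6, dtype=cp.float64).reshape(2, 3)
    print(describe_array(a, s))

The function body only touches StridedMemoryView attributes, so the same describe_array works unchanged for any other DLPack- or CAI-capable producer.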

cuda_core/docs/source/api.rst

Lines changed: 15 additions & 0 deletions
@@ -31,3 +31,18 @@ CUDA compilation toolchain
    :toctree: generated/
 
    Program
+
+
+.. module:: cuda.core.experimental.utils
+
+Utility functions
+-----------------
+
+.. autosummary::
+   :toctree: generated/
+
+   viewable
+
+   :template: dataclass.rst
+
+   StridedMemoryView

cuda_core/docs/source/conf.py

Lines changed: 9 additions & 0 deletions
@@ -33,6 +33,7 @@
     'sphinx.ext.autodoc',
     'sphinx.ext.autosummary',
     'sphinx.ext.napoleon',
+    'sphinx.ext.intersphinx',
     'myst_nb',
     'enum_tools.autoenum',
     'sphinx_copybutton',
@@ -81,3 +82,11 @@
 
 # skip cmdline prompts
 copybutton_exclude = '.linenos, .gp'
+
+intersphinx_mapping = {
+    'python': ('https://docs.python.org/3/', None),
+    'numpy': ('https://numpy.org/doc/stable/', None),
+}
+
+napoleon_google_docstring = False
+napoleon_numpy_docstring = True
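
For context, the intersphinx mapping together with the napoleon switch is what turns type names such as numpy.dtype in the new NumPy-style docstrings into cross-links to the Python and NumPy documentation. A hypothetical docstring (the function name is made up) that this configuration would render:

    def example(dtype):
        """A hypothetical function, only to illustrate the Sphinx setup above.

        Parameters
        ----------
        dtype : numpy.dtype
            napoleon_numpy_docstring=True makes Sphinx parse this NumPy-style
            section; intersphinx then resolves ``numpy.dtype`` against the
            NumPy inventory registered above.
        """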

cuda_core/tests/test_utils.py

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,7 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
+
 try:
     import cupy as cp
 except ImportError:
@@ -30,12 +34,14 @@ def test_viewable_cpu(self, in_arr):
 
         @viewable((0,))
        def my_func(arr):
+            # stream_ptr=-1 means "the consumer does not care"
             view = arr.view(-1)
             self._check_view(view, in_arr)
 
         my_func(in_arr)
 
     def test_strided_memory_view_cpu(self, in_arr):
+        # stream_ptr=-1 means "the consumer does not care"
         view = StridedMemoryView(in_arr, stream_ptr=-1)
         self._check_view(view, in_arr)
 
@@ -93,6 +99,7 @@ def test_viewable_gpu(self, in_arr, stream):
         # TODO: use the device fixture?
         dev = Device()
         dev.set_current()
+        # This is the consumer stream
         s = dev.create_stream() if stream else None
 
         @viewable((0,))
@@ -106,6 +113,7 @@ def test_strided_memory_view_cpu(self, in_arr, stream):
         # TODO: use the device fixture?
         dev = Device()
         dev.set_current()
+        # This is the consumer stream
         s = dev.create_stream() if stream else None
 
         view = StridedMemoryView(
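
The helper _check_view is not shown in this diff; a minimal sketch of what such a helper could assert for a CPU NumPy input, written against the documented StridedMemoryView attributes (purely illustrative, not the repository's actual helper):

    def check_view(view, in_arr):
        # Illustrative only; the real _check_view in test_utils.py may differ.
        assert view.ptr == in_arr.ctypes.data   # same underlying buffer (NumPy input)
        assert view.shape == in_arr.shape
        assert view.dtype == in_arr.dtype
        assert view.device_id == 0              # CPU tensors report device 0
        assert view.readonly == (not in_arr.flags.writeable)
        assert view.exporting_obj is in_arr     # reference back to the original object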
