Skip to content

Commit d064ae7

Browse files
committed
Full test coverage for Buffer.__dlpack__, __dlpack_device__
1 parent 29c9fb8 commit d064ae7

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

cuda_core/cuda/core/experimental/_memory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,14 @@ def __dlpack__(
179179
# Note: we ignore the stream argument entirely (as if it is -1).
180180
# It is the user's responsibility to maintain stream order.
181181
if dl_device is not None:
182-
raise BufferError("Sorry not supported: dl_device other than None")
182+
raise BufferError("Sorry, not supported: dl_device other than None")
183183
if copy is True:
184-
raise BufferError("Sorry not supported: copy=True")
184+
raise BufferError("Sorry, not supported: copy=True")
185185
if max_version is None:
186186
versioned = False
187187
else:
188188
if not isinstance(max_version, tuple) or len(max_version) != 2:
189-
raise RuntimeError(f"Expected max_version Tuple[int, int], got {max_version}")
189+
raise BufferError(f"Expected max_version Tuple[int, int], got {max_version}")
190190
versioned = max_version >= (1, 0)
191191
capsule = make_py_capsule(self, versioned)
192192
return capsule

cuda_core/tests/test_memory.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313

1414
import ctypes
1515

16+
import pytest
17+
1618
from cuda.core.experimental import Device
17-
from cuda.core.experimental._memory import Buffer, MemoryResource
19+
from cuda.core.experimental._memory import Buffer, DLDeviceType, MemoryResource
1820
from cuda.core.experimental._utils import handle_return
1921

2022

@@ -116,6 +118,12 @@ def device_id(self) -> int:
116118
raise RuntimeError("the pinned memory resource is not bound to any GPU")
117119

118120

121+
class NullMemoryResource(DummyHostMemoryResource):
122+
@property
123+
def is_host_accessible(self) -> bool:
124+
return False
125+
126+
119127
def buffer_initialization(dummy_mr: MemoryResource):
120128
buffer = dummy_mr.allocate(size=1024)
121129
assert buffer.handle != 0
@@ -211,3 +219,46 @@ def test_buffer_close():
211219
buffer_close(DummyHostMemoryResource())
212220
buffer_close(DummyUnifiedMemoryResource(device))
213221
buffer_close(DummyPinnedMemoryResource(device))
222+
223+
224+
def test_buffer_dunder_dlpack():
225+
device = Device()
226+
device.set_current()
227+
dummy_mr = DummyDeviceMemoryResource(device)
228+
buffer = dummy_mr.allocate(size=1024)
229+
capsule = buffer.__dlpack__()
230+
assert "dltensor" in repr(capsule)
231+
capsule = buffer.__dlpack__(max_version=(1, 0))
232+
assert "dltensor" in repr(capsule)
233+
with pytest.raises(BufferError, match=r"^Sorry, not supported: dl_device other than None$"):
234+
buffer.__dlpack__(dl_device=[])
235+
with pytest.raises(BufferError, match=r"^Sorry, not supported: copy=True$"):
236+
buffer.__dlpack__(copy=True)
237+
with pytest.raises(BufferError, match=r"^Expected max_version Tuple\[int, int\], got \[\]$"):
238+
buffer.__dlpack__(max_version=[])
239+
with pytest.raises(BufferError, match=r"^Expected max_version Tuple\[int, int\], got \(9, 8, 7\)$"):
240+
buffer.__dlpack__(max_version=(9, 8, 7))
241+
242+
243+
@pytest.mark.parametrize(
244+
("DummyMR", "expected"),
245+
[
246+
(DummyDeviceMemoryResource, (DLDeviceType.kDLCUDA, 0)),
247+
(DummyHostMemoryResource, (DLDeviceType.kDLCPU, 0)),
248+
(DummyUnifiedMemoryResource, (DLDeviceType.kDLCUDAHost, 0)),
249+
(DummyPinnedMemoryResource, (DLDeviceType.kDLCUDAHost, 0)),
250+
],
251+
)
252+
def test_buffer_dunder_dlpack_device_success(DummyMR, expected):
253+
device = Device()
254+
device.set_current()
255+
dummy_mr = DummyMR() if DummyMR is DummyHostMemoryResource else DummyMR(device)
256+
buffer = dummy_mr.allocate(size=1024)
257+
assert buffer.__dlpack_device__() == expected
258+
259+
260+
def test_buffer_dunder_dlpack_device_failure():
261+
dummy_mr = NullMemoryResource()
262+
buffer = dummy_mr.allocate(size=1024)
263+
with pytest.raises(BufferError, match=r"^buffer is neither device-accessible nor host-accessible$"):
264+
buffer.__dlpack_device__()

0 commit comments

Comments
 (0)