Skip to content

Commit b59918b

Browse files
committed
bug fixes and test updates
1 parent 18a9c85 commit b59918b

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,11 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
177177
cdef object capsule
178178
try:
179179
capsule = obj.__dlpack__(
180-
stream=stream_ptr,
180+
stream=int(stream_ptr) if stream_ptr else None,
181181
max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
182182
except TypeError:
183183
capsule = obj.__dlpack__(
184-
stream=stream_ptr)
184+
stream=int(stream_ptr) if stream_ptr else None)
185185

186186
cdef void* data = NULL
187187
if cpython.PyCapsule_IsValid(

cuda_core/tests/example_tests/test_basic_examples.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import pytest
1515

16+
from cuda.core.experimental import Device
1617
from .utils import run_example
1718

1819
samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
@@ -23,3 +24,5 @@
2324
class TestExamples:
2425
def test_example(self, example, deinit_cuda):
2526
run_example(samples_path, example)
27+
if Device().device_id != 0:
28+
Device(0).set_current()

cuda_core/tests/test_stream.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cuda.core.experimental import Device, Stream, StreamOptions
1212
from cuda.core.experimental._event import Event
1313
from cuda.core.experimental._stream import LEGACY_DEFAULT_STREAM, PER_THREAD_DEFAULT_STREAM, default_stream
14+
from cuda.core.experimental._utils import driver
1415

1516

1617
def test_stream_init():
@@ -26,7 +27,7 @@ def test_stream_init_with_options(init_cuda):
2627

2728
def test_stream_handle(init_cuda):
2829
stream = Device().create_stream(options=StreamOptions())
29-
assert isinstance(stream.handle, int)
30+
assert isinstance(stream.handle, driver.CUstream)
3031

3132

3233
def test_stream_is_nonblocking(init_cuda):
@@ -90,7 +91,8 @@ def test_stream_from_foreign_stream(init_cuda):
9091
device = Device()
9192
other_stream = device.create_stream(options=StreamOptions())
9293
stream = device.create_stream(obj=other_stream)
93-
assert other_stream.handle == stream.handle
94+
# convert to int to work around NVIDIA/cuda-python#465
95+
assert int(other_stream.handle) == int(stream.handle)
9496
device = stream.device
9597
assert isinstance(device, Device)
9698
context = stream.context

0 commit comments

Comments
 (0)