Skip to content

Commit 15ebef3

Browse files
committed
add tests for viewable
1 parent 6820f30 commit 15ebef3

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

cuda_core/cuda/core/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
from cuda.core.experimental._launcher import LaunchConfig, launch
88
from cuda.core.experimental._program import Program
99
from cuda.core.experimental._stream import Stream, StreamOptions
10+
from cuda.core.experimental import utils

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ cdef class StridedMemoryView:
2626
shape: tuple = None
2727
strides: tuple = None # in counts, not bytes
2828
dtype: numpy.dtype = None
29-
device_id: int = None # -1 for CPU
29+
device_id: int = None # 0 for CPU
3030
device_accessible: bool = None
3131
readonly: bool = None
3232
exporting_obj: Any = None

cuda_core/tests/test_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
try:
3+
import cupy as cp
4+
except ImportError:
5+
cp = None
6+
import pytest
7+
8+
from cuda.core.experimental import Device
9+
from cuda.core.experimental.utils import StridedMemoryView, viewable
10+
11+
12+
@pytest.mark.parametrize(
13+
"in_arr,", (
14+
np.empty(3, dtype=np.int32),
15+
np.empty((6, 6), dtype=np.float64)[::2, ::2],
16+
np.empty((3, 4), order='F'),
17+
)
18+
)
19+
def test_viewable_cpu(in_arr):
20+
21+
@viewable((0,))
22+
def my_func(arr):
23+
view = arr.view(-1)
24+
assert view.ptr == in_arr.ctypes.data
25+
assert view.shape == in_arr.shape
26+
if in_arr.flags.c_contiguous:
27+
assert view.strides is None
28+
else:
29+
assert view.strides == tuple(s // in_arr.dtype.itemsize for s in in_arr.strides)
30+
assert view.dtype == in_arr.dtype
31+
assert view.device_id == 0
32+
assert view.device_accessible == False
33+
assert view.exporting_obj is in_arr
34+
35+
my_func(in_arr)
36+
37+
38+
if cp is not None:
39+
40+
@pytest.mark.parametrize(
41+
"in_arr,stream", (
42+
(cp.empty(3, dtype=cp.complex64), None),
43+
(cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True),
44+
(cp.empty((3, 4), order='F'), True),
45+
)
46+
)
47+
def test_viewable_gpu(in_arr, stream):
48+
# TODO: use the device fixture?
49+
dev = Device()
50+
dev.set_current()
51+
s = dev.create_stream() if stream else None
52+
53+
@viewable((0,))
54+
def my_func(arr):
55+
view = arr.view(s.handle if s else -1)
56+
assert view.ptr == in_arr.data.ptr
57+
assert view.shape == in_arr.shape
58+
strides_in_counts = tuple(s // in_arr.dtype.itemsize for s in in_arr.strides)
59+
if in_arr.flags.c_contiguous:
60+
assert view.strides in (None, strides_in_counts)
61+
else:
62+
assert view.strides == strides_in_counts
63+
assert view.dtype == in_arr.dtype
64+
assert view.device_id == dev.device_id
65+
assert view.device_accessible == True
66+
assert view.exporting_obj is in_arr
67+
68+
my_func(in_arr)

0 commit comments

Comments
 (0)