Skip to content

Commit 80b2556

Browse files
committed
use numba to test CAI & make test a bit cleaner
1 parent 15ebef3 commit 80b2556

File tree

1 file changed

+64
-28
lines changed

1 file changed

+64
-28
lines changed

cuda_core/tests/test_utils.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1-
import numpy as np
21
try:
32
import cupy as cp
43
except ImportError:
54
cp = None
5+
try:
6+
from numba import cuda as numba_cuda
7+
except ImportError:
8+
numba_cuda = None
9+
import numpy as np
610
import pytest
711

812
from cuda.core.experimental import Device
913
from cuda.core.experimental.utils import StridedMemoryView, viewable
1014

1115

16+
def convert_strides_to_counts(strides, itemsize):
17+
return tuple(s // itemsize for s in strides)
18+
19+
1220
@pytest.mark.parametrize(
1321
"in_arr,", (
1422
np.empty(3, dtype=np.int32),
@@ -21,12 +29,15 @@ def test_viewable_cpu(in_arr):
2129
@viewable((0,))
2230
def my_func(arr):
2331
view = arr.view(-1)
32+
assert isinstance(view, StridedMemoryView)
2433
assert view.ptr == in_arr.ctypes.data
2534
assert view.shape == in_arr.shape
35+
strides_in_counts = convert_strides_to_counts(
36+
in_arr.strides, in_arr.dtype.itemsize)
2637
if in_arr.flags.c_contiguous:
2738
assert view.strides is None
2839
else:
29-
assert view.strides == tuple(s // in_arr.dtype.itemsize for s in in_arr.strides)
40+
assert view.strides == strides_in_counts
3041
assert view.dtype == in_arr.dtype
3142
assert view.device_id == 0
3243
assert view.device_accessible == False
@@ -35,34 +46,59 @@ def my_func(arr):
3546
my_func(in_arr)
3647

3748

38-
if cp is not None:
39-
40-
@pytest.mark.parametrize(
41-
"in_arr,stream", (
49+
def gpu_array_samples():
50+
# TODO: this function would initialize the device at test collection time
51+
samples = []
52+
if cp is not None:
53+
samples += [
4254
(cp.empty(3, dtype=cp.complex64), None),
4355
(cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True),
4456
(cp.empty((3, 4), order='F'), True),
45-
)
57+
]
58+
# Numba's device_array is the only known array container that does not
59+
# support DLPack (so that we get to test the CAI coverage).
60+
if numba_cuda is not None:
61+
samples += [
62+
(numba_cuda.device_array((2,), dtype=np.int8), None),
63+
(numba_cuda.device_array((4, 2), dtype=np.float32), True),
64+
]
65+
return samples
66+
67+
68+
def gpu_array_ptr(arr):
69+
if cp is not None and isinstance(arr, cp.ndarray):
70+
return arr.data.ptr
71+
if numba_cuda is not None and isinstance(arr, numba_cuda.cudadrv.devicearray.DeviceNDArray):
72+
return arr.device_ctypes_pointer.value
73+
assert False, f"{arr=}"
74+
75+
76+
@pytest.mark.parametrize(
77+
"in_arr,stream", (
78+
*gpu_array_samples(),
4679
)
47-
def test_viewable_gpu(in_arr, stream):
48-
# TODO: use the device fixture?
49-
dev = Device()
50-
dev.set_current()
51-
s = dev.create_stream() if stream else None
80+
)
81+
def test_viewable_gpu(in_arr, stream):
82+
# TODO: use the device fixture?
83+
dev = Device()
84+
dev.set_current()
85+
s = dev.create_stream() if stream else None
86+
87+
@viewable((0,))
88+
def my_func(arr):
89+
view = arr.view(s.handle if s else -1)
90+
assert isinstance(view, StridedMemoryView)
91+
assert view.ptr == gpu_array_ptr(in_arr)
92+
assert view.shape == in_arr.shape
93+
strides_in_counts = convert_strides_to_counts(
94+
in_arr.strides, in_arr.dtype.itemsize)
95+
if in_arr.flags["C_CONTIGUOUS"]:
96+
assert view.strides in (None, strides_in_counts)
97+
else:
98+
assert view.strides == strides_in_counts
99+
assert view.dtype == in_arr.dtype
100+
assert view.device_id == dev.device_id
101+
assert view.device_accessible == True
102+
assert view.exporting_obj is in_arr
52103

53-
@viewable((0,))
54-
def my_func(arr):
55-
view = arr.view(s.handle if s else -1)
56-
assert view.ptr == in_arr.data.ptr
57-
assert view.shape == in_arr.shape
58-
strides_in_counts = tuple(s // in_arr.dtype.itemsize for s in in_arr.strides)
59-
if in_arr.flags.c_contiguous:
60-
assert view.strides in (None, strides_in_counts)
61-
else:
62-
assert view.strides == strides_in_counts
63-
assert view.dtype == in_arr.dtype
64-
assert view.device_id == dev.device_id
65-
assert view.device_accessible == True
66-
assert view.exporting_obj is in_arr
67-
68-
my_func(in_arr)
104+
my_func(in_arr)

0 commit comments

Comments
 (0)