1
- import numpy as np
2
1
try :
3
2
import cupy as cp
4
3
except ImportError :
5
4
cp = None
5
+ try :
6
+ from numba import cuda as numba_cuda
7
+ except ImportError :
8
+ numba_cuda = None
9
+ import numpy as np
6
10
import pytest
7
11
8
12
from cuda .core .experimental import Device
9
13
from cuda .core .experimental .utils import StridedMemoryView , viewable
10
14
11
15
16
+ def convert_strides_to_counts (strides , itemsize ):
17
+ return tuple (s // itemsize for s in strides )
18
+
19
+
12
20
@pytest .mark .parametrize (
13
21
"in_arr," , (
14
22
np .empty (3 , dtype = np .int32 ),
@@ -21,12 +29,15 @@ def test_viewable_cpu(in_arr):
21
29
@viewable ((0 ,))
22
30
def my_func (arr ):
23
31
view = arr .view (- 1 )
32
+ assert isinstance (view , StridedMemoryView )
24
33
assert view .ptr == in_arr .ctypes .data
25
34
assert view .shape == in_arr .shape
35
+ strides_in_counts = convert_strides_to_counts (
36
+ in_arr .strides , in_arr .dtype .itemsize )
26
37
if in_arr .flags .c_contiguous :
27
38
assert view .strides is None
28
39
else :
29
- assert view .strides == tuple ( s // in_arr . dtype . itemsize for s in in_arr . strides )
40
+ assert view .strides == strides_in_counts
30
41
assert view .dtype == in_arr .dtype
31
42
assert view .device_id == 0
32
43
assert view .device_accessible == False
@@ -35,34 +46,59 @@ def my_func(arr):
35
46
my_func (in_arr )
36
47
37
48
38
- if cp is not None :
39
-
40
- @pytest .mark .parametrize (
41
- "in_arr,stream" , (
49
+ def gpu_array_samples ():
50
+ # TODO: this function would initialize the device at test collection time
51
+ samples = []
52
+ if cp is not None :
53
+ samples += [
42
54
(cp .empty (3 , dtype = cp .complex64 ), None ),
43
55
(cp .empty ((6 , 6 ), dtype = cp .float64 )[::2 , ::2 ], True ),
44
56
(cp .empty ((3 , 4 ), order = 'F' ), True ),
45
- )
57
+ ]
58
+ # Numba's device_array is the only known array container that does not
59
+ # support DLPack (so that we get to test the CAI coverage).
60
+ if numba_cuda is not None :
61
+ samples += [
62
+ (numba_cuda .device_array ((2 ,), dtype = np .int8 ), None ),
63
+ (numba_cuda .device_array ((4 , 2 ), dtype = np .float32 ), True ),
64
+ ]
65
+ return samples
66
+
67
+
68
+ def gpu_array_ptr (arr ):
69
+ if cp is not None and isinstance (arr , cp .ndarray ):
70
+ return arr .data .ptr
71
+ if numba_cuda is not None and isinstance (arr , numba_cuda .cudadrv .devicearray .DeviceNDArray ):
72
+ return arr .device_ctypes_pointer .value
73
+ assert False , f"{ arr = } "
74
+
75
+
76
+ @pytest .mark .parametrize (
77
+ "in_arr,stream" , (
78
+ * gpu_array_samples (),
46
79
)
47
- def test_viewable_gpu (in_arr , stream ):
48
- # TODO: use the device fixture?
49
- dev = Device ()
50
- dev .set_current ()
51
- s = dev .create_stream () if stream else None
80
+ )
81
+ def test_viewable_gpu (in_arr , stream ):
82
+ # TODO: use the device fixture?
83
+ dev = Device ()
84
+ dev .set_current ()
85
+ s = dev .create_stream () if stream else None
86
+
87
+ @viewable ((0 ,))
88
+ def my_func (arr ):
89
+ view = arr .view (s .handle if s else - 1 )
90
+ assert isinstance (view , StridedMemoryView )
91
+ assert view .ptr == gpu_array_ptr (in_arr )
92
+ assert view .shape == in_arr .shape
93
+ strides_in_counts = convert_strides_to_counts (
94
+ in_arr .strides , in_arr .dtype .itemsize )
95
+ if in_arr .flags ["C_CONTIGUOUS" ]:
96
+ assert view .strides in (None , strides_in_counts )
97
+ else :
98
+ assert view .strides == strides_in_counts
99
+ assert view .dtype == in_arr .dtype
100
+ assert view .device_id == dev .device_id
101
+ assert view .device_accessible == True
102
+ assert view .exporting_obj is in_arr
52
103
53
- @viewable ((0 ,))
54
- def my_func (arr ):
55
- view = arr .view (s .handle if s else - 1 )
56
- assert view .ptr == in_arr .data .ptr
57
- assert view .shape == in_arr .shape
58
- strides_in_counts = tuple (s // in_arr .dtype .itemsize for s in in_arr .strides )
59
- if in_arr .flags .c_contiguous :
60
- assert view .strides in (None , strides_in_counts )
61
- else :
62
- assert view .strides == strides_in_counts
63
- assert view .dtype == in_arr .dtype
64
- assert view .device_id == dev .device_id
65
- assert view .device_accessible == True
66
- assert view .exporting_obj is in_arr
67
-
68
- my_func (in_arr )
104
+ my_func (in_arr )
0 commit comments