@@ -22,16 +22,15 @@ def convert_strides_to_counts(strides, itemsize):
22
22
23
23
24
24
@pytest .mark .parametrize (
25
- "in_arr," , (
25
+ "in_arr," ,
26
+ (
26
27
np .empty (3 , dtype = np .int32 ),
27
28
np .empty ((6 , 6 ), dtype = np .float64 )[::2 , ::2 ],
28
- np .empty ((3 , 4 ), order = 'F' ),
29
- )
29
+ np .empty ((3 , 4 ), order = "F" ),
30
+ ),
30
31
)
31
32
class TestViewCPU :
32
-
33
33
def test_viewable_cpu (self , in_arr ):
34
-
35
34
@viewable ((0 ,))
36
35
def my_func (arr ):
37
36
# stream_ptr=-1 means "the consumer does not care"
@@ -49,8 +48,7 @@ def _check_view(self, view, in_arr):
49
48
assert isinstance (view , StridedMemoryView )
50
49
assert view .ptr == in_arr .ctypes .data
51
50
assert view .shape == in_arr .shape
52
- strides_in_counts = convert_strides_to_counts (
53
- in_arr .strides , in_arr .dtype .itemsize )
51
+ strides_in_counts = convert_strides_to_counts (in_arr .strides , in_arr .dtype .itemsize )
54
52
if in_arr .flags .c_contiguous :
55
53
assert view .strides is None
56
54
else :
@@ -68,7 +66,7 @@ def gpu_array_samples():
68
66
samples += [
69
67
(cp .empty (3 , dtype = cp .complex64 ), None ),
70
68
(cp .empty ((6 , 6 ), dtype = cp .float64 )[::2 , ::2 ], True ),
71
- (cp .empty ((3 , 4 ), order = 'F' ), True ),
69
+ (cp .empty ((3 , 4 ), order = "F" ), True ),
72
70
]
73
71
# Numba's device_array is the only known array container that does not
74
72
# support DLPack (so that we get to test the CAI coverage).
@@ -88,13 +86,8 @@ def gpu_array_ptr(arr):
88
86
assert False , f"{ arr = } "
89
87
90
88
91
- @pytest .mark .parametrize (
92
- "in_arr,stream" , (
93
- * gpu_array_samples (),
94
- )
95
- )
89
+ @pytest .mark .parametrize ("in_arr,stream" , (* gpu_array_samples (),))
96
90
class TestViewGPU :
97
-
98
91
def test_viewable_gpu (self , in_arr , stream ):
99
92
# TODO: use the device fixture?
100
93
dev = Device ()
@@ -116,17 +109,14 @@ def test_strided_memory_view_cpu(self, in_arr, stream):
116
109
# This is the consumer stream
117
110
s = dev .create_stream () if stream else None
118
111
119
- view = StridedMemoryView (
120
- in_arr ,
121
- stream_ptr = s .handle if s else - 1 )
112
+ view = StridedMemoryView (in_arr , stream_ptr = s .handle if s else - 1 )
122
113
self ._check_view (view , in_arr , dev )
123
114
124
115
def _check_view (self , view , in_arr , dev ):
125
116
assert isinstance (view , StridedMemoryView )
126
117
assert view .ptr == gpu_array_ptr (in_arr )
127
118
assert view .shape == in_arr .shape
128
- strides_in_counts = convert_strides_to_counts (
129
- in_arr .strides , in_arr .dtype .itemsize )
119
+ strides_in_counts = convert_strides_to_counts (in_arr .strides , in_arr .dtype .itemsize )
130
120
if in_arr .flags ["C_CONTIGUOUS" ]:
131
121
assert view .strides in (None , strides_in_counts )
132
122
else :
0 commit comments