Skip to content

Commit 8ce0aa6

Browse files
committed
add tests for creating views directly
1 parent 80b2556 commit 8ce0aa6

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

cuda_core/tests/test_utils.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,22 @@ def convert_strides_to_counts(strides, itemsize):
2424
np.empty((3, 4), order='F'),
2525
)
2626
)
27-
def test_viewable_cpu(in_arr):
27+
class TestViewCPU:
2828

29-
@viewable((0,))
30-
def my_func(arr):
31-
view = arr.view(-1)
29+
def test_viewable_cpu(self, in_arr):
30+
31+
@viewable((0,))
32+
def my_func(arr):
33+
view = arr.view(-1)
34+
self._check_view(view, in_arr)
35+
36+
my_func(in_arr)
37+
38+
def test_strided_memory_view_cpu(self, in_arr):
39+
view = StridedMemoryView(in_arr, stream_ptr=-1)
40+
self._check_view(view, in_arr)
41+
42+
def _check_view(self, view, in_arr):
3243
assert isinstance(view, StridedMemoryView)
3344
assert view.ptr == in_arr.ctypes.data
3445
assert view.shape == in_arr.shape
@@ -43,8 +54,6 @@ def my_func(arr):
4354
assert view.device_accessible == False
4455
assert view.exporting_obj is in_arr
4556

46-
my_func(in_arr)
47-
4857

4958
def gpu_array_samples():
5059
# TODO: this function would initialize the device at test collection time
@@ -78,15 +87,33 @@ def gpu_array_ptr(arr):
7887
*gpu_array_samples(),
7988
)
8089
)
81-
def test_viewable_gpu(in_arr, stream):
82-
# TODO: use the device fixture?
83-
dev = Device()
84-
dev.set_current()
85-
s = dev.create_stream() if stream else None
86-
87-
@viewable((0,))
88-
def my_func(arr):
89-
view = arr.view(s.handle if s else -1)
90+
class TestViewGPU:
91+
92+
def test_viewable_gpu(self, in_arr, stream):
93+
# TODO: use the device fixture?
94+
dev = Device()
95+
dev.set_current()
96+
s = dev.create_stream() if stream else None
97+
98+
@viewable((0,))
99+
def my_func(arr):
100+
view = arr.view(s.handle if s else -1)
101+
self._check_view(view, in_arr, dev)
102+
103+
my_func(in_arr)
104+
105+
def test_strided_memory_view_cpu(self, in_arr, stream):
106+
# TODO: use the device fixture?
107+
dev = Device()
108+
dev.set_current()
109+
s = dev.create_stream() if stream else None
110+
111+
view = StridedMemoryView(
112+
in_arr,
113+
stream_ptr=s.handle if s else -1)
114+
self._check_view(view, in_arr, dev)
115+
116+
def _check_view(self, view, in_arr, dev):
90117
assert isinstance(view, StridedMemoryView)
91118
assert view.ptr == gpu_array_ptr(in_arr)
92119
assert view.shape == in_arr.shape
@@ -100,5 +127,3 @@ def my_func(arr):
100127
assert view.device_id == dev.device_id
101128
assert view.device_accessible == True
102129
assert view.exporting_obj is in_arr
103-
104-
my_func(in_arr)

0 commit comments

Comments
 (0)