Skip to content

Commit f1239a2

Browse files
committed
fix formatting
1 parent 8295d56 commit f1239a2

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

cuda_core/cuda/core/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
from cuda.core.experimental import utils
56
from cuda.core.experimental._device import Device
67
from cuda.core.experimental._event import EventOptions
78
from cuda.core.experimental._launcher import LaunchConfig, launch
89
from cuda.core.experimental._program import Program
910
from cuda.core.experimental._stream import Stream, StreamOptions
10-
from cuda.core.experimental import utils

cuda_core/tests/test_utils.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,15 @@ def convert_strides_to_counts(strides, itemsize):
2222

2323

2424
@pytest.mark.parametrize(
25-
"in_arr,", (
25+
"in_arr,",
26+
(
2627
np.empty(3, dtype=np.int32),
2728
np.empty((6, 6), dtype=np.float64)[::2, ::2],
28-
np.empty((3, 4), order='F'),
29-
)
29+
np.empty((3, 4), order="F"),
30+
),
3031
)
3132
class TestViewCPU:
32-
3333
def test_viewable_cpu(self, in_arr):
34-
3534
@viewable((0,))
3635
def my_func(arr):
3736
# stream_ptr=-1 means "the consumer does not care"
@@ -49,8 +48,7 @@ def _check_view(self, view, in_arr):
4948
assert isinstance(view, StridedMemoryView)
5049
assert view.ptr == in_arr.ctypes.data
5150
assert view.shape == in_arr.shape
52-
strides_in_counts = convert_strides_to_counts(
53-
in_arr.strides, in_arr.dtype.itemsize)
51+
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
5452
if in_arr.flags.c_contiguous:
5553
assert view.strides is None
5654
else:
@@ -68,7 +66,7 @@ def gpu_array_samples():
6866
samples += [
6967
(cp.empty(3, dtype=cp.complex64), None),
7068
(cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True),
71-
(cp.empty((3, 4), order='F'), True),
69+
(cp.empty((3, 4), order="F"), True),
7270
]
7371
# Numba's device_array is the only known array container that does not
7472
# support DLPack (so that we get to test the CAI coverage).
@@ -88,13 +86,8 @@ def gpu_array_ptr(arr):
8886
assert False, f"{arr=}"
8987

9088

91-
@pytest.mark.parametrize(
92-
"in_arr,stream", (
93-
*gpu_array_samples(),
94-
)
95-
)
89+
@pytest.mark.parametrize("in_arr,stream", (*gpu_array_samples(),))
9690
class TestViewGPU:
97-
9891
def test_viewable_gpu(self, in_arr, stream):
9992
# TODO: use the device fixture?
10093
dev = Device()
@@ -116,17 +109,14 @@ def test_strided_memory_view_cpu(self, in_arr, stream):
116109
# This is the consumer stream
117110
s = dev.create_stream() if stream else None
118111

119-
view = StridedMemoryView(
120-
in_arr,
121-
stream_ptr=s.handle if s else -1)
112+
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
122113
self._check_view(view, in_arr, dev)
123114

124115
def _check_view(self, view, in_arr, dev):
125116
assert isinstance(view, StridedMemoryView)
126117
assert view.ptr == gpu_array_ptr(in_arr)
127118
assert view.shape == in_arr.shape
128-
strides_in_counts = convert_strides_to_counts(
129-
in_arr.strides, in_arr.dtype.itemsize)
119+
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
130120
if in_arr.flags["C_CONTIGUOUS"]:
131121
assert view.strides in (None, strides_in_counts)
132122
else:

0 commit comments

Comments
 (0)