Skip to content

Commit 5d8eaaa

Browse files
committed
Update impacted tests with dpnp.__dlpack__
1 parent 6a33f0d commit 5d8eaaa

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

dpnp/tests/test_dlpack.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import dpctl
12
import numpy
23
import pytest
3-
from numpy.testing import assert_array_equal
4+
from numpy.testing import assert_array_equal, assert_raises
45

56
import dpnp
67

@@ -10,11 +11,20 @@
1011

1112

1213
class TestDLPack:
13-
@pytest.mark.parametrize("stream", [None, 1])
14+
@pytest.mark.parametrize("stream", [None, dpctl.SyclQueue()])
1415
def test_stream(self, stream):
1516
x = dpnp.arange(5)
1617
x.__dlpack__(stream=stream)
1718

19+
@pytest.mark.parametrize(
20+
"stream",
21+
[1, dict(), dpctl.SyclDevice()],
22+
ids=["scalar", "dictionary", "device"],
23+
)
24+
def test_invaid_stream(self, stream):
25+
x = dpnp.arange(5)
26+
assert_raises(TypeError, x.__dlpack__, stream=stream)
27+
1828
@pytest.mark.parametrize("copy", [True, None, False])
1929
def test_copy(self, copy):
2030
x = dpnp.arange(5)

dpnp/tests/third_party/cupy/core_tests/test_dlpack.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,13 @@ def test_stream(self):
206206
for src_s in [self._get_stream(s) for s in allowed_streams]:
207207
for dst_s in [self._get_stream(s) for s in allowed_streams]:
208208
orig_array = _gen_array(cupy.float32, alloc_q=src_s)
209-
dltensor = orig_array.__dlpack__(stream=orig_array)
209+
210+
q = dpctl.SyclQueue(
211+
orig_array.sycl_context,
212+
orig_array.sycl_device,
213+
property="enable_profiling",
214+
)
215+
dltensor = orig_array.__dlpack__(stream=q)
210216

211217
out_array = dlp.from_dlpack_capsule(dltensor)
212218
out_array = cupy.from_dlpack(out_array, device=dst_s)

0 commit comments

Comments
 (0)