Skip to content

Commit 67e53d3

Browse files
Add case of dlpack test to expand coverage
1 parent 6bb2d73 commit 67e53d3

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,20 @@ def test_copy_via_host():
702702
get_queue_or_skip()
703703
x = dpt.ones(1, dtype="i4")
704704
x_np = np.ones(1, dtype="i4")
705-
y = dpt.from_dlpack(x_np, device=x.__dlpack_device__())
705+
x_dl_dev = x.__dlpack_device__()
706+
y = dpt.from_dlpack(x_np, device=x_dl_dev)
706707
assert isinstance(y, dpt.usm_ndarray)
707708
assert y.sycl_device == x.sycl_device
708709
assert y.usm_type == "device"
710+
711+
with pytest.raises(ValueError):
712+
dpt.from_dlpack(x_np, device=(1, 0, 0))
713+
with pytest.raises(BufferError):
714+
dpt.from_dlpack(x, device=(2, 0))
715+
716+
num_devs = dpctl.get_num_devices()
717+
if num_devs > 1:
718+
j = [i for i in range(num_devs) if i != x_dl_dev[1]][0]
719+
z = dpt.from_dlpack(x, device=(x_dl_dev[0], j))
720+
assert isinstance(z, dpt.usm_ndarray)
721+
assert z.usm_type == "device"

0 commit comments

Comments
 (0)