Skip to content

Commit 6a33f0d

Browse files
committed
Update tests for dpnp.to_device
1 parent 0dfcfed commit 6a33f0d

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1949,12 +1949,14 @@ def test_basic(self, device_from, device_to):
19491949

19501950
y = x.to_device(device_to)
19511951
assert y.sycl_device == device_to
1952+
assert (x.asnumpy() == y.asnumpy()).all()
19521953

19531954
def test_to_queue(self):
19541955
x = dpnp.full(100, 2, dtype=dpnp.int64)
19551956
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
19561957

19571958
y = x.to_device(q_prof)
1959+
assert (x.asnumpy() == y.asnumpy()).all()
19581960
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19591961

19601962
def test_stream(self):
@@ -1963,21 +1965,34 @@ def test_stream(self):
19631965
q_exec = dpctl.SyclQueue(x.sycl_device)
19641966

19651967
y = x.to_device(q_prof, stream=q_exec)
1968+
assert (x.asnumpy() == y.asnumpy()).all()
19661969
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19671970

19681971
q_exec = dpctl.SyclQueue(x.sycl_device)
19691972
_ = dpnp.linspace(0, 20, num=10**5, sycl_queue=q_exec)
19701973
y = x.to_device(q_prof, stream=q_exec)
1974+
assert (x.asnumpy() == y.asnumpy()).all()
19711975
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19721976

19731977
def test_stream_no_sync(self):
19741978
x = dpnp.full(100, 2, dtype=dpnp.int64)
19751979
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
19761980

1977-
for stream in [None, 1, dpctl.SyclDevice(), x.sycl_queue]:
1981+
for stream in [None, x.sycl_queue]:
19781982
y = x.to_device(q_prof, stream=stream)
1983+
assert (x.asnumpy() == y.asnumpy()).all()
19791984
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19801985

1986+
@pytest.mark.parametrize(
1987+
"stream",
1988+
[1, dict(), dpctl.SyclDevice()],
1989+
ids=["scalar", "dictionary", "device"],
1990+
)
1991+
def test_invalid_stream(self, stream):
1992+
x = dpnp.ones(2, dtype=dpnp.int64)
1993+
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1994+
assert_raises(TypeError, x.to_device, q_prof, stream=stream)
1995+
19811996

19821997
@pytest.mark.parametrize(
19831998
"device",

0 commit comments

Comments
 (0)