Skip to content

Commit b950e84

Browse files
committed
Add test coverage
1 parent cde6795 commit b950e84

File tree

1 file changed

+43
-15
lines changed

1 file changed

+43
-15
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,23 +1932,51 @@ def test_svd(shape, full_matrices, compute_uv, device):
19321932
assert_sycl_queue_equal(dpnp_s_queue, expected_queue)
19331933

19341934

1935-
@pytest.mark.parametrize(
1936-
"device_from",
1937-
valid_devices,
1938-
ids=[device.filter_string for device in valid_devices],
1939-
)
1940-
@pytest.mark.parametrize(
1941-
"device_to",
1942-
valid_devices,
1943-
ids=[device.filter_string for device in valid_devices],
1944-
)
1945-
def test_to_device(device_from, device_to):
1946-
data = [1.0, 1.0, 1.0, 1.0, 1.0]
1935+
class TestToDevice:
1936+
@pytest.mark.parametrize(
1937+
"device_from",
1938+
valid_devices,
1939+
ids=[device.filter_string for device in valid_devices],
1940+
)
1941+
@pytest.mark.parametrize(
1942+
"device_to",
1943+
valid_devices,
1944+
ids=[device.filter_string for device in valid_devices],
1945+
)
1946+
def test_basic(self, device_from, device_to):
1947+
data = [1.0, 1.0, 1.0, 1.0, 1.0]
1948+
x = dpnp.array(data, dtype=dpnp.float32, device=device_from)
1949+
1950+
y = x.to_device(device_to)
1951+
assert y.sycl_device == device_to
1952+
1953+
def test_to_queue(self):
1954+
x = dpnp.full(100, 2, dtype=dpnp.int64)
1955+
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1956+
1957+
y = x.to_device(q_prof)
1958+
assert_sycl_queue_equal(y.sycl_queue, q_prof)
1959+
1960+
def test_stream(self):
1961+
x = dpnp.full(100, 2, dtype=dpnp.int64)
1962+
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1963+
q_exec = dpctl.SyclQueue(x.sycl_device)
1964+
1965+
y = x.to_device(q_prof, stream=q_exec)
1966+
assert_sycl_queue_equal(y.sycl_queue, q_prof)
1967+
1968+
q_exec = dpctl.SyclQueue(x.sycl_device)
1969+
_ = dpnp.linspace(0, 20, num=10**5, sycl_queue=q_exec)
1970+
y = x.to_device(q_prof, stream=q_exec)
1971+
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19471972

1948-
x = dpnp.array(data, dtype=dpnp.float32, device=device_from)
1949-
y = x.to_device(device_to)
1973+
def test_stream_no_sync(self):
1974+
x = dpnp.full(100, 2, dtype=dpnp.int64)
1975+
q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
19501976

1951-
assert y.sycl_device == device_to
1977+
for stream in [None, 1, dpctl.SyclDevice(), x.sycl_queue]:
1978+
y = x.to_device(q_prof, stream=stream)
1979+
assert_sycl_queue_equal(y.sycl_queue, q_prof)
19521980

19531981

19541982
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)