Skip to content

Commit 721f9b2

Browse files
dpt.asarray can not migrate arrays across devices
Within compute follows data programming model, we recommend explicit migration of inputs to the same allocation queue. This is to be done using `usm_ndarray.to_device` method, or with `dpctl.tensor.asarray` operation. This operation until this commit was not able to migrate data across devices. This commit adds such an ability by copying via host. Test was added.
1 parent 52773aa commit 721f9b2

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

dpctl/tensor/_ctors.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,15 @@ def _asarray_from_usm_ndarray(
186186
order=order,
187187
buffer_ctor_kwargs={"queue": copy_q},
188188
)
189-
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
190-
src=usm_ndary, dst=res, sycl_queue=copy_q
191-
)
192-
hev.wait()
189+
eq = dpctl.utils.get_execution_queue([usm_ndary.sycl_queue, copy_q])
190+
if eq is not None:
191+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
192+
src=usm_ndary, dst=res, sycl_queue=eq
193+
)
194+
hev.wait()
195+
else:
196+
tmp = dpt.asnumpy(usm_ndary)
197+
res[...] = tmp
193198
return res
194199

195200

dpctl/tests/test_tensor_asarray.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pytest
19+
from helper import get_queue_or_skip
1920

2021
import dpctl
2122
import dpctl.tensor as dpt
@@ -193,10 +194,7 @@ def test_asarray_scalars():
193194

194195

195196
def test_asarray_copy_false():
196-
try:
197-
q = dpctl.SyclQueue()
198-
except dpctl.SyclQueueCreationError:
199-
pytest.skip("Could not create a queue")
197+
q = get_queue_or_skip()
200198
rng = np.random.default_rng()
201199
Xnp = rng.integers(low=-255, high=255, size=(10, 4), dtype=np.int64)
202200
X = dpt.from_numpy(Xnp, usm_type="device", sycl_queue=q)
@@ -229,10 +227,15 @@ def test_asarray_copy_false():
229227

230228

231229
def test_asarray_invalid_dtype():
232-
try:
233-
q = dpctl.SyclQueue()
234-
except dpctl.SyclQueueCreationError:
235-
pytest.skip("Could not create a queue")
230+
q = get_queue_or_skip()
236231
Xnp = np.array([1, 2, 3], dtype=object)
237232
with pytest.raises(TypeError):
238233
dpt.asarray(Xnp, sycl_queue=q)
234+
235+
236+
def test_asarray_cross_device():
237+
q = get_queue_or_skip()
238+
qprof = dpctl.SyclQueue(property="enable_profiling")
239+
x = dpt.empty(10, dtype="i8", sycl_queue=q)
240+
y = dpt.asarray(x, sycl_queue=qprof)
241+
assert y.sycl_queue == qprof

0 commit comments

Comments
 (0)