Skip to content

Commit 074ec3a

Browse files
Merge pull request #1331 from IntelPython/to_device-stream-support
usm_ndarray.to_device(dev, stream=queue) support
2 parents 6f0969c + 706d80f commit 074ec3a

File tree

6 files changed

+129
-31
lines changed

6 files changed

+129
-31
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,23 @@ def _broadcast_shapes(sh1, sh2):
246246
).shape
247247

248248

249+
def _broadcast_strides(X_shape, X_strides, res_ndim):
250+
"""
251+
Broadcasts strides to match the given dimensions;
252+
returns tuple type strides.
253+
"""
254+
out_strides = [0] * res_ndim
255+
X_shape_len = len(X_shape)
256+
str_dim = -X_shape_len
257+
for i in range(X_shape_len):
258+
shape_value = X_shape[i]
259+
if not shape_value == 1:
260+
out_strides[str_dim] = X_strides[i]
261+
str_dim += 1
262+
263+
return tuple(out_strides)
264+
265+
249266
def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
250267
if any(
251268
not isinstance(arg, dpt.usm_ndarray)
@@ -268,7 +285,7 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
268285
except ValueError as exc:
269286
raise ValueError("Shapes of two arrays are not compatible") from exc
270287

271-
if dst.size < src.size:
288+
if dst.size < src.size and dst.size < np.prod(common_shape):
272289
raise ValueError("Destination is smaller ")
273290

274291
if len(common_shape) > dst.ndim:
@@ -279,13 +296,33 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
279296
common_shape = common_shape[ones_count:]
280297

281298
if src.ndim < len(common_shape):
282-
new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
299+
new_src_strides = _broadcast_strides(
300+
src.shape, src.strides, len(common_shape)
301+
)
302+
src_same_shape = dpt.usm_ndarray(
303+
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
304+
)
305+
elif src.ndim == len(common_shape):
306+
new_src_strides = _broadcast_strides(
307+
src.shape, src.strides, len(common_shape)
308+
)
283309
src_same_shape = dpt.usm_ndarray(
284310
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
285311
)
286312
else:
287-
src_same_shape = src
288-
src_same_shape.shape = common_shape
313+
# since broadcasting succeeded, src.ndim is greater because of
314+
# leading sequence of ones, so we trim it
315+
n = len(common_shape)
316+
new_src_strides = _broadcast_strides(
317+
src.shape[-n:], src.strides[-n:], n
318+
)
319+
src_same_shape = dpt.usm_ndarray(
320+
common_shape,
321+
dtype=src.dtype,
322+
buffer=src.usm_data,
323+
strides=new_src_strides,
324+
offset=src._element_offset,
325+
)
289326

290327
_copy_same_shape(dst, src_same_shape)
291328

dpctl/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,5 +343,5 @@ def nonzero(arr):
343343
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(arr)}"
344344
)
345345
if arr.ndim == 0:
346-
raise ValueError("Array of positive rank is exepcted")
346+
raise ValueError("Array of positive rank is expected")
347347
return _nonzero_impl(arr)

dpctl/tensor/_manipulation_functions.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import dpctl.tensor._tensor_impl as ti
2626
import dpctl.utils as dputils
2727

28+
from ._copy_utils import _broadcast_strides
2829
from ._type_utils import _to_device_supported_dtype
2930

3031
__doc__ = (
@@ -120,23 +121,6 @@ def __repr__(self):
120121
return self._finfo.__repr__()
121122

122123

123-
def _broadcast_strides(X_shape, X_strides, res_ndim):
124-
"""
125-
Broadcasts strides to match the given dimensions;
126-
returns tuple type strides.
127-
"""
128-
out_strides = [0] * res_ndim
129-
X_shape_len = len(X_shape)
130-
str_dim = -X_shape_len
131-
for i in range(X_shape_len):
132-
shape_value = X_shape[i]
133-
if not shape_value == 1:
134-
out_strides[str_dim] = X_strides[i]
135-
str_dim += 1
136-
137-
return tuple(out_strides)
138-
139-
140124
def _broadcast_shape_impl(shapes):
141125
if len(set(shapes)) == 1:
142126
return shapes[0]

dpctl/tensor/_stride_utils.pxi

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ cdef int _from_input_shape_strides(
6464
cdef int j
6565
cdef bint all_incr = 1
6666
cdef bint all_decr = 1
67-
cdef bint all_incr_modified = 0
68-
cdef bint all_decr_modified = 0
67+
cdef bint strides_inspected = 0
6968
cdef Py_ssize_t elem_count = 1
7069
cdef Py_ssize_t min_shift = 0
7170
cdef Py_ssize_t max_shift = 0
@@ -167,27 +166,33 @@ cdef int _from_input_shape_strides(
167166
while (j < nd and shape_arr[j] == 1):
168167
j = j + 1
169168
if j < nd:
169+
strides_inspected = 1
170170
if all_incr:
171-
all_incr_modified = 1
172171
all_incr = (
173172
(strides_arr[i] > 0) and
174173
(strides_arr[j] > 0) and
175174
(strides_arr[i] <= strides_arr[j])
176175
)
177176
if all_decr:
178-
all_decr_modified = 1
179177
all_decr = (
180178
(strides_arr[i] > 0) and
181179
(strides_arr[j] > 0) and
182180
(strides_arr[i] >= strides_arr[j])
183181
)
184182
i = j
185183
else:
184+
if not strides_inspected:
185+
# all dimensions have size 1 except
186+
# dimension 'i'. Array is both C and F
187+
# contiguous
188+
strides_inspected = 1
189+
all_incr = (strides_arr[i] == 1)
190+
all_decr = all_incr
186191
break
187192
# should only set contig flags on actually obtained
188193
# values, rather than default values
189-
all_incr = all_incr and all_incr_modified
190-
all_decr = all_decr and all_decr_modified
194+
all_incr = all_incr and strides_inspected
195+
all_decr = all_decr and strides_inspected
191196
if all_incr and all_decr:
192197
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
193198
elif all_incr:

dpctl/tensor/_usmarray.pyx

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ cdef class usm_ndarray:
816816
return _take_multi_index(res, adv_ind, adv_ind_start_p)
817817

818818

819-
def to_device(self, target):
819+
def to_device(self, target, stream=None):
820820
""" to_device(target_device)
821821
822822
Transfers this array to specified target device.
@@ -856,6 +856,14 @@ cdef class usm_ndarray:
856856
cdef c_dpctl.DPCTLSyclQueueRef QRef = NULL
857857
cdef c_dpmem._Memory arr_buf
858858
d = Device.create_device(target)
859+
860+
if (stream is None or type(stream) is not dpctl.SyclQueue or
861+
stream == self.sycl_queue):
862+
pass
863+
else:
864+
ev = self.sycl_queue.submit_barrier()
865+
stream.submit_barrier(dependent_events=[ev])
866+
859867
if (d.sycl_context == self.sycl_context):
860868
arr_buf = <c_dpmem._Memory> self.usm_data
861869
QRef = (<c_dpctl.SyclQueue> d.sycl_queue).get_queue_ref()
@@ -1167,8 +1175,6 @@ cdef class usm_ndarray:
11671175
if adv_ind_start_p < 0:
11681176
# basic slicing
11691177
if isinstance(rhs, usm_ndarray):
1170-
if Xv.size == 0:
1171-
return
11721178
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
11731179
else:
11741180
if hasattr(rhs, "__sycl_usm_array_interface__"):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,25 @@ def test_usm_ndarray_flags():
109109
x.flags["C"] = False
110110

111111

112+
def test_usm_ndarray_flags_bug_gh_1334():
    """Regression test for gh-1334: checks that each result below is
    flagged as both C- and F-contiguous."""
    get_queue_or_skip()

    # C-ordered input reshaped to (1, 6, 1)
    c_src = dpt.ones((2, 3), dtype="u4")
    c_res = dpt.reshape(c_src, (1, 6, 1))
    assert c_res.flags["C"] and c_res.flags["F"]

    # F-ordered input reshaped to (1, 6, 1) in F order
    f_src = dpt.ones((2, 3), dtype="u4", order="F")
    f_res = dpt.reshape(f_src, (1, 6, 1), order="F")
    assert f_res.flags["C"] and f_res.flags["F"]

    # reduction over the last two axes with keepdims=True
    red_src = dpt.ones((2, 3, 4), dtype="i8")
    red_res = dpt.sum(red_src, axis=(1, 2), keepdims=True)
    assert red_res.flags["C"] and red_res.flags["F"]

    # negative-step slice along the second axis
    col_src = dpt.ones((2, 1), dtype="?")
    col_res = col_src[:, 1::-1]
    assert col_res.flags["F"] and col_res.flags["C"]
129+
130+
112131
@pytest.mark.parametrize(
113132
"dtype",
114133
[
@@ -1012,6 +1031,53 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
10121031
Zusm_empty[Ellipsis] = Zusm_3d[0, 0, 0:0]
10131032

10141033

1034+
def test_setitem_broadcasting():
    """Assigning a (3, 1) array of zeros into a (2, 3, 4) array of ones
    must leave the destination filled with zeros."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="u4")
    target[Ellipsis] = dpt.zeros((3, 1), dtype=target.dtype)

    reference = np.zeros(target.shape, dtype=target.dtype)
    assert np.array_equal(dpt.asnumpy(target), reference)
1041+
1042+
1043+
def test_setitem_broadcasting_empty_dst_validation():
    """Shapes are still validated when the destination is empty:
    assigning (2, 0, 3, 4) into (2, 0, 5, 4) must raise ValueError."""
    get_queue_or_skip()
    empty_dst = dpt.ones((2, 0, 5, 4), dtype="i8")
    mismatched_src = dpt.ones((2, 0, 3, 4), dtype="i8")
    with pytest.raises(ValueError):
        empty_dst[Ellipsis] = mismatched_src
1050+
1051+
1052+
def test_setitem_broadcasting_empty_dst_edge_case():
    """RHS is shrunk to an empty array by the broadcasting rule,
    hence no exception is expected."""
    get_queue_or_skip()
    empty_dst = dpt.ones(1, dtype="i8")[0:0]
    scalar_src = dpt.ones((), dtype="i8")
    empty_dst[Ellipsis] = scalar_src
1059+
1060+
1061+
def test_setitem_broadcasting_src_ndim_equal_dst_ndim():
    """Source and destination have the same rank; the source's unit
    middle axis is broadcast across the destination."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="i4")
    target[Ellipsis] = dpt.zeros((2, 1, 4), dtype="i4")

    reference = np.zeros(target.shape, dtype=target.dtype)
    assert np.array_equal(dpt.asnumpy(target), reference)
1069+
1070+
1071+
def test_setitem_broadcasting_src_ndim_greater_than_dst_ndim():
    """Source has one more axis than the destination (a leading axis of
    extent one); assignment must still fill the destination."""
    get_queue_or_skip()
    target = dpt.ones((2, 3, 4), dtype="i4")
    target[Ellipsis] = dpt.zeros((1, 2, 1, 4), dtype="i4")

    reference = np.zeros(target.shape, dtype=target.dtype)
    assert np.array_equal(dpt.asnumpy(target), reference)
1079+
1080+
10151081
@pytest.mark.parametrize(
10161082
"dtype",
10171083
_all_dtypes,

0 commit comments

Comments
 (0)