Skip to content

Commit c368034

Browse files
Fix for gh-1330
When the source array must be broadcast, non-zero strides corresponding to unit-size dimensions must be zeroed out; otherwise, if such a dimension is broadcast, the constructor would expect a bigger buffer than is really necessary.
1 parent f77f7a4 commit c368034

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,10 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
279279
common_shape = common_shape[ones_count:]
280280

281281
if src.ndim < len(common_shape):
282-
new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
282+
pad_count = len(common_shape) - src.ndim
283+
new_src_strides = (0,) * pad_count + tuple(
284+
s if d > 1 else 0 for s, d in zip(src.strides, src.shape)
285+
)
283286
src_same_shape = dpt.usm_ndarray(
284287
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
285288
)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,15 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
10121012
Zusm_empty[Ellipsis] = Zusm_3d[0, 0, 0:0]
10131013

10141014

1015+
def test_setitem_boradcasting():
1016+
get_queue_or_skip()
1017+
dst = dpt.ones((2, 3, 4), dtype="u4")
1018+
src = dpt.zeros((3, 1), dtype=dst.dtype)
1019+
dst[...] = src
1020+
expected = np.zeros(dst.shape, dtype=dst.dtype)
1021+
assert np.array_equal(dpt.asnumpy(dst), expected)
1022+
1023+
10151024
@pytest.mark.parametrize(
10161025
"dtype",
10171026
_all_dtypes,

0 commit comments

Comments
 (0)