Skip to content

Commit 8a706ae

Browse files
authored
Removes malformed fast path for zero-size arrays in repeat (#1682)
* Removes malformed fast path for size-zero arrays in `repeat` * Adds a test for fixed behavior of `repeat` with 0-size arrays
1 parent f3d8ee7 commit 8a706ae

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,6 @@ def repeat(x, repeats, /, *, axis=None):
838838
f"got {type(repeats)}"
839839
)
840840

841-
if axis_size == 0:
842-
return dpt.empty(x_shape, dtype=x.dtype, sycl_queue=exec_q)
843-
844841
if scalar:
845842
res_axis_size = repeats * axis_size
846843
if axis is not None:

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,3 +1485,35 @@ def test_tile_arg_validation():
14851485
x = dpt.empty(())
14861486
with pytest.raises(TypeError):
14871487
dpt.tile(x, dict())
1488+
1489+
1490+
def test_repeat_0_size():
1491+
get_queue_or_skip()
1492+
1493+
x = dpt.ones((0, 10, 0), dtype="i4")
1494+
repetitions = 2
1495+
res = dpt.repeat(x, repetitions)
1496+
assert res.shape == (0,)
1497+
res = dpt.repeat(x, repetitions, axis=2)
1498+
assert res.shape == x.shape
1499+
res = dpt.repeat(x, repetitions, axis=1)
1500+
axis_sz = x.shape[1] * repetitions
1501+
assert res.shape == (0, 20, 0)
1502+
1503+
repetitions = dpt.asarray(2, dtype="i4")
1504+
res = dpt.repeat(x, repetitions)
1505+
assert res.shape == (0,)
1506+
res = dpt.repeat(x, repetitions, axis=2)
1507+
assert res.shape == x.shape
1508+
res = dpt.repeat(x, repetitions, axis=1)
1509+
assert res.shape == (0, 20, 0)
1510+
1511+
repetitions = dpt.arange(10, dtype="i4")
1512+
res = dpt.repeat(x, repetitions, axis=1)
1513+
axis_sz = dpt.sum(repetitions)
1514+
assert res.shape == (0, axis_sz, 0)
1515+
1516+
repetitions = (2,) * 10
1517+
res = dpt.repeat(x, repetitions, axis=1)
1518+
axis_sz = 2 * x.shape[1]
1519+
assert res.shape == (0, axis_sz, 0)

0 commit comments

Comments
 (0)