Skip to content

Commit 52a44a9

Browse files
committed
repeat now permits 0d arrays when axis is None
- Also adds a check that the sole element of a length 1 tuple is an integer before proceeding to the scalar case
1 parent 160dac5 commit 52a44a9

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def repeat(x, repeats, axis=None):
962962

963963
x_ndim = x.ndim
964964
if axis is None:
965-
if x_ndim != 1:
965+
if x_ndim > 1:
966966
raise ValueError(
967967
f"`axis` cannot be `None` for array of dimension {x_ndim}"
968968
)
@@ -990,6 +990,8 @@ def repeat(x, repeats, axis=None):
990990
)
991991
elif len_reps == 1:
992992
repeats = repeats[0]
993+
if not isinstance(repeats, int):
994+
raise TypeError("`repeats` elements must be integers")
993995
if repeats < 0:
994996
raise ValueError("`repeats` elements must be positive")
995997
scalar = True

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,7 @@ def test_repeat_arg_validation():
12191219
with pytest.raises(ValueError):
12201220
dpt.repeat(x, 2, axis=1)
12211221

1222-
# x.ndim must be 1 for axis=None
1222+
# x.ndim cannot be > 1 for axis=None
12231223
x = dpt.empty((5, 10))
12241224
with pytest.raises(ValueError):
12251225
dpt.repeat(x, 2, axis=None)
@@ -1233,6 +1233,10 @@ def test_repeat_arg_validation():
12331233
with pytest.raises(TypeError):
12341234
dpt.repeat(x, 2.0)
12351235

1236+
# repeats tuple elements must be integers
1237+
with pytest.raises(TypeError):
1238+
dpt.repeat(x, (2.0,))
1239+
12361240
# repeats tuple must be the same length as axis
12371241
with pytest.raises(ValueError):
12381242
dpt.repeat(x, (1, 2))

0 commit comments

Comments
 (0)