Skip to content

Commit 4a59cfb

Browse files
committed
repeat repeats parameter relaxed to permit lists and ranges
Docstring has been adjusted to reflect changes to `axis` as well as new `repeats` types Corrected a bug in the behavior of `repeat` for size 1 `repeats` Python sequences
1 parent 1b632b8 commit 4a59cfb

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -928,20 +928,26 @@ def repeat(x, repeats, axis=None):
928928
Args:
929929
x (usm_ndarray): input array
930930
931-
repeat (Union[int, Tuple[int, ...]]):
931+
repeats (Union[int, Sequence[int, ...], usm_ndarray]):
932932
The number of repetitions for each element.
933-
`repeats` is broadcasted to fit the shape of the given axis.
933+
`repeats` is broadcast to fit the shape of the given axis.
934+
If `repeats` is an array, it must have an integer data type.
935+
Otherwise, `repeats` must be a Python integer, tuple, list, or
936+
range.
934937
935938
axis (Optional[int]):
936-
The axis along which to repeat values. The `axis` is required
937-
if input array has more than one dimension.
939+
The axis along which to repeat values. If `axis` is `None`, the
940+
function repeats elements of the flattened array.
941+
Default: `None`.
938942
939943
Returns:
940944
usm_narray:
941945
Array with repeated elements.
942-
The returned array must have the same data type as `x`,
943-
is created on the same device as `x` and has the same USM
944-
allocation type as `x`.
946+
The returned array must have the same data type as `x`, is created
947+
on the same device as `x` and has the same USM allocation type as
948+
`x`. If `axis` is `None`, the returned array is one-dimensional,
949+
otherwise, it has the same shape as `x`, except for the axis along
950+
which elements were repeated.
945951
946952
Raises:
947953
AxisError: if `axis` value is invalid.
@@ -1005,30 +1011,30 @@ def repeat(x, repeats, axis=None):
10051011
if not dpt.all(repeats >= 0):
10061012
raise ValueError("`repeats` elements must be positive")
10071013

1008-
elif isinstance(repeats, tuple):
1014+
elif isinstance(repeats, (tuple, list, range)):
10091015
usm_type = x.usm_type
10101016
exec_q = x.sycl_queue
10111017

10121018
len_reps = len(repeats)
1013-
if len_reps != axis_size:
1014-
raise ValueError(
1015-
"`repeats` tuple must have the same length as the repeated "
1016-
"axis"
1017-
)
1018-
elif len_reps == 1:
1019+
if len_reps == 1:
10191020
repeats = repeats[0]
10201021
if repeats < 0:
10211022
raise ValueError("`repeats` elements must be positive")
10221023
scalar = True
10231024
else:
1025+
if len_reps != axis_size:
1026+
raise ValueError(
1027+
"`repeats` sequence must have the same length as the "
1028+
"repeated axis"
1029+
)
10241030
repeats = dpt.asarray(
10251031
repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
10261032
)
10271033
if not dpt.all(repeats >= 0):
10281034
raise ValueError("`repeats` elements must be positive")
10291035
else:
10301036
raise TypeError(
1031-
"Expected int, tuple, or `usm_ndarray` for second argument,"
1037+
"Expected int, sequence, or `usm_ndarray` for second argument,"
10321038
f"got {type(repeats)}"
10331039
)
10341040

0 commit comments

Comments
 (0)