@@ -928,20 +928,26 @@ def repeat(x, repeats, axis=None):
928
928
Args:
929
929
x (usm_ndarray): input array
930
930
931
- repeat (Union[int, Tuple [int, ...]]):
931
+ repeats (Union[int, Sequence [int, ...], usm_ndarray ]):
932
932
The number of repetitions for each element.
933
- `repeats` is broadcasted to fit the shape of the given axis.
933
+ `repeats` is broadcast to fit the shape of the given axis.
934
+ If `repeats` is an array, it must have an integer data type.
935
+ Otherwise, `repeats` must be a Python integer, tuple, list, or
936
+ range.
934
937
935
938
axis (Optional[int]):
936
- The axis along which to repeat values. The `axis` is required
937
- if input array has more than one dimension.
939
+ The axis along which to repeat values. If `axis` is `None`, the
940
+ function repeats elements of the flattened array.
941
+ Default: `None`.
938
942
939
943
Returns:
940
944
usm_narray:
941
945
Array with repeated elements.
942
- The returned array must have the same data type as `x`,
943
- is created on the same device as `x` and has the same USM
944
- allocation type as `x`.
946
+ The returned array must have the same data type as `x`, is created
947
+ on the same device as `x` and has the same USM allocation type as
948
+ `x`. If `axis` is `None`, the returned array is one-dimensional,
949
+ otherwise, it has the same shape as `x`, except for the axis along
950
+ which elements were repeated.
945
951
946
952
Raises:
947
953
AxisError: if `axis` value is invalid.
@@ -1005,30 +1011,30 @@ def repeat(x, repeats, axis=None):
1005
1011
if not dpt .all (repeats >= 0 ):
1006
1012
raise ValueError ("`repeats` elements must be positive" )
1007
1013
1008
- elif isinstance (repeats , tuple ):
1014
+ elif isinstance (repeats , ( tuple , list , range ) ):
1009
1015
usm_type = x .usm_type
1010
1016
exec_q = x .sycl_queue
1011
1017
1012
1018
len_reps = len (repeats )
1013
- if len_reps != axis_size :
1014
- raise ValueError (
1015
- "`repeats` tuple must have the same length as the repeated "
1016
- "axis"
1017
- )
1018
- elif len_reps == 1 :
1019
+ if len_reps == 1 :
1019
1020
repeats = repeats [0 ]
1020
1021
if repeats < 0 :
1021
1022
raise ValueError ("`repeats` elements must be positive" )
1022
1023
scalar = True
1023
1024
else :
1025
+ if len_reps != axis_size :
1026
+ raise ValueError (
1027
+ "`repeats` sequence must have the same length as the "
1028
+ "repeated axis"
1029
+ )
1024
1030
repeats = dpt .asarray (
1025
1031
repeats , dtype = dpt .int64 , usm_type = usm_type , sycl_queue = exec_q
1026
1032
)
1027
1033
if not dpt .all (repeats >= 0 ):
1028
1034
raise ValueError ("`repeats` elements must be positive" )
1029
1035
else :
1030
1036
raise TypeError (
1031
- "Expected int, tuple , or `usm_ndarray` for second argument,"
1037
+ "Expected int, sequence , or `usm_ndarray` for second argument,"
1032
1038
f"got { type (repeats )} "
1033
1039
)
1034
1040
0 commit comments