|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 | 17 |
|
18 |
| -from itertools import chain, product, repeat |
| 18 | +import itertools |
| 19 | +import operator |
| 20 | +from itertools import chain, product |
19 | 21 |
|
20 | 22 | import numpy as np
|
| 23 | +from numpy import AxisError |
21 | 24 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
|
22 | 25 |
|
23 | 26 | import dpctl
|
@@ -132,7 +135,7 @@ def _broadcast_shape_impl(shapes):
|
132 | 135 | diff = biggest - nds[i]
|
133 | 136 | if diff > 0:
|
134 | 137 | ty = type(shapes[i])
|
135 |
| - shapes[i] = ty(chain(repeat(1, diff), shapes[i])) |
| 138 | + shapes[i] = ty(chain(itertools.repeat(1, diff), shapes[i])) |
136 | 139 | common_shape = []
|
137 | 140 | for axis in range(biggest):
|
138 | 141 | lengths = [s[axis] for s in shapes]
|
@@ -928,6 +931,94 @@ def swapaxes(X, axis1, axis2):
|
928 | 931 | return dpt.permute_dims(X, tuple(ind))
|
929 | 932 |
|
930 | 933 |
|
def repeat(x, repeats, axis=None):
    """Repeat elements of an array along a single axis.

    Args:
        x (usm_ndarray):
            Input array.
        repeats (int or tuple):
            Number of repetitions. An ``int`` (or a one-element tuple) is
            applied to every element along ``axis``. A longer tuple must
            have exactly as many entries as the repeated axis has elements.
            Values must not be negative; zero is accepted.
        axis (int, optional):
            Axis along which to repeat. ``None`` is only permitted for
            one-dimensional input, where it means axis ``0``. For a
            zero-dimensional input the only valid value is ``0``.

    Returns:
        usm_ndarray:
            New array with the dtype, USM type and queue of ``x``. Its shape
            equals that of ``x`` except along ``axis``. A zero-dimensional
            input yields a one-dimensional result.

    Raises:
        TypeError:
            If ``x`` is not a ``usm_ndarray``, or ``repeats`` is neither an
            ``int`` nor a ``tuple``.
        ValueError:
            If ``axis`` is ``None`` for input that is not one-dimensional,
            if any repeat count is negative, or if a multi-element
            ``repeats`` tuple does not match the size of the repeated axis.
        AxisError:
            If ``axis`` is out of bounds for ``x``.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    if axis is None:
        if x_ndim != 1:
            raise ValueError(
                f"`axis` cannot be `None` for array of dimension {x_ndim}"
            )
        axis = 0

    x_shape = x.shape
    if x_ndim > 0:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        if axis != 0:
            # The exception used to be constructed but never raised, so an
            # invalid axis on a 0-d input went unnoticed.
            raise AxisError("`axis` must be `0` for input of dimension `0`")
        # A 0-d array is treated as a single element repeated along axis 0.
        axis_size = x.size

    scalar = False
    if isinstance(repeats, int):
        scalar = True
        if repeats < 0:
            raise ValueError("`repeats` must be a positive integer")
    elif isinstance(repeats, tuple):
        len_reps = len(repeats)
        if len_reps == 1:
            # A one-element tuple is handled like a scalar, whatever the
            # size of the axis. This must be tested before the length check
            # below, otherwise it is only reachable when axis_size == 1.
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
            scalar = True
        elif len_reps != axis_size:
            raise ValueError(
                "`repeats` tuple must have the same length as the repeated axis"
            )
    else:
        raise TypeError(
            f"Expected int or tuple for second argument, got {type(repeats)}"
        )

    usm_type = x.usm_type
    exec_q = x.sycl_queue

    if axis_size == 0:
        # Nothing to repeat: the result keeps the (empty) input shape.
        # usm_type is passed so the result matches the input's allocation
        # type, like the other result allocations below.
        return dpt.empty(
            x_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )

    if scalar:
        res_axis_size = repeats * axis_size
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        # Skip the kernel launch when the result is empty (repeats == 0).
        if res_axis_size > 0:
            hev, _ = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
            )
            hev.wait()
    else:
        repeats = dpt.asarray(repeats, dtype="i8", sycl_queue=exec_q)
        if not dpt.all(repeats >= 0):
            raise ValueError("`repeats` elements must be positive")
        cumsum = dpt.empty(
            (axis_size,), dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
        )
        # The return value of `_cumsum_1d` is used as the output extent
        # along `axis` (presumably the total of all repeat counts -- confirm
        # against the `ti` implementation).
        res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        if res_axis_size > 0:
            hev, _ = ti._repeat_by_sequence(
                src=x,
                dst=res,
                reps=repeats,
                cumsum=cumsum,
                axis=axis,
                sycl_queue=exec_q,
            )
            hev.wait()
    return res
| 1020 | + |
| 1021 | + |
931 | 1022 | def _supported_dtype(dtypes):
|
932 | 1023 | for dtype in dtypes:
|
933 | 1024 | if dtype.char not in "?bBhHiIlLqQefdFD":
|
|
0 commit comments