|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 | 17 |
|
| 18 | +import itertools |
18 | 19 | import operator
|
19 |
| -from itertools import chain, repeat |
20 | 20 |
|
21 | 21 | import numpy as np
|
| 22 | +from numpy import AxisError |
22 | 23 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
|
23 | 24 |
|
24 | 25 | import dpctl
|
@@ -133,7 +134,9 @@ def _broadcast_shape_impl(shapes):
|
133 | 134 | diff = biggest - nds[i]
|
134 | 135 | if diff > 0:
|
135 | 136 | ty = type(shapes[i])
|
136 |
| - shapes[i] = ty(chain(repeat(1, diff), shapes[i])) |
| 137 | + shapes[i] = ty( |
| 138 | + itertools.chain(itertools.repeat(1, diff), shapes[i]) |
| 139 | + ) |
137 | 140 | common_shape = []
|
138 | 141 | for axis in range(biggest):
|
139 | 142 | lengths = [s[axis] for s in shapes]
|
@@ -918,6 +921,302 @@ def swapaxes(X, axis1, axis2):
|
918 | 921 | return dpt.permute_dims(X, tuple(ind))
|
919 | 922 |
|
920 | 923 |
|
| 924 | +def repeat(x, repeats, axis=None): |
| 925 | + """repeat(x, repeats, axis=None) |
| 926 | +
|
| 927 | + Repeat elements of an array. |
| 928 | +
|
| 929 | + Args: |
| 930 | + x (usm_ndarray): input array |
| 931 | +
|
| 932 | + repeat (Union[int, Tuple[int, ...]]): |
| 933 | + The number of repetitions for each element. |
| 934 | + `repeats` is broadcasted to fit the shape of the given axis. |
| 935 | +
|
| 936 | + axis (Optional[int]): |
| 937 | + The axis along which to repeat values. The `axis` is required |
| 938 | + if input array has more than one dimension. |
| 939 | +
|
| 940 | + Returns: |
| 941 | + usm_narray: |
| 942 | + Array with repeated elements. |
| 943 | + The returned array must have the same data type as `x`, |
| 944 | + is created on the same device as `x` and has the same USM |
| 945 | + allocation type as `x`. |
| 946 | +
|
| 947 | + Raises: |
| 948 | + AxisError: if `axis` value is invalid. |
| 949 | + """ |
| 950 | + if not isinstance(x, dpt.usm_ndarray): |
| 951 | + raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") |
| 952 | + |
| 953 | + x_ndim = x.ndim |
| 954 | + if axis is None: |
| 955 | + if x_ndim > 1: |
| 956 | + raise ValueError( |
| 957 | + f"`axis` cannot be `None` for array of dimension {x_ndim}" |
| 958 | + ) |
| 959 | + axis = 0 |
| 960 | + |
| 961 | + x_shape = x.shape |
| 962 | + if x_ndim > 0: |
| 963 | + axis = normalize_axis_index(operator.index(axis), x_ndim) |
| 964 | + axis_size = x_shape[axis] |
| 965 | + else: |
| 966 | + if axis != 0: |
| 967 | + AxisError("`axis` must be `0` for input of dimension `0`") |
| 968 | + axis_size = x.size |
| 969 | + |
| 970 | + scalar = False |
| 971 | + if isinstance(repeats, int): |
| 972 | + if repeats < 0: |
| 973 | + raise ValueError("`repeats` must be a positive integer") |
| 974 | + usm_type = x.usm_type |
| 975 | + exec_q = x.sycl_queue |
| 976 | + scalar = True |
| 977 | + elif isinstance(repeats, dpt.usm_ndarray): |
| 978 | + if repeats.ndim > 1: |
| 979 | + raise ValueError( |
| 980 | + "`repeats` array must be 0- or 1-dimensional, got" |
| 981 | + "{repeats.ndim}" |
| 982 | + ) |
| 983 | + exec_q = dpctl.utils.get_execution_queue( |
| 984 | + (x.sycl_queue, repeats.sycl_queue) |
| 985 | + ) |
| 986 | + if exec_q is None: |
| 987 | + raise dputils.ExecutionPlacementError( |
| 988 | + "Execution placement can not be unambiguously inferred " |
| 989 | + "from input arguments." |
| 990 | + ) |
| 991 | + usm_type = dpctl.utils.get_coerced_usm_type( |
| 992 | + ( |
| 993 | + x.usm_type, |
| 994 | + repeats.usm_type, |
| 995 | + ) |
| 996 | + ) |
| 997 | + dpctl.utils.validate_usm_type(usm_type, allow_none=False) |
| 998 | + if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"): |
| 999 | + raise TypeError( |
| 1000 | + f"`repeats` data type `{repeats.dtype}` cannot be cast to " |
| 1001 | + "`int64` according to the casting rule ''safe.''" |
| 1002 | + ) |
| 1003 | + if repeats.size == 1: |
| 1004 | + scalar = True |
| 1005 | + # bring the single element to the host |
| 1006 | + repeats = int(repeats) |
| 1007 | + if repeats < 0: |
| 1008 | + raise ValueError("`repeats` elements must be positive") |
| 1009 | + else: |
| 1010 | + if repeats.size != axis_size: |
| 1011 | + raise ValueError( |
| 1012 | + "`repeats` array must be broadcastable to the size of " |
| 1013 | + "the repeated axis" |
| 1014 | + ) |
| 1015 | + if not dpt.all(repeats >= 0): |
| 1016 | + raise ValueError("`repeats` elements must be positive") |
| 1017 | + |
| 1018 | + elif isinstance(repeats, tuple): |
| 1019 | + usm_type = x.usm_type |
| 1020 | + exec_q = x.sycl_queue |
| 1021 | + |
| 1022 | + len_reps = len(repeats) |
| 1023 | + if len_reps != axis_size: |
| 1024 | + raise ValueError( |
| 1025 | + "`repeats` tuple must have the same length as the repeated " |
| 1026 | + "axis" |
| 1027 | + ) |
| 1028 | + elif len_reps == 1: |
| 1029 | + repeats = repeats[0] |
| 1030 | + if repeats < 0: |
| 1031 | + raise ValueError("`repeats` elements must be positive") |
| 1032 | + scalar = True |
| 1033 | + else: |
| 1034 | + repeats = dpt.asarray( |
| 1035 | + repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q |
| 1036 | + ) |
| 1037 | + if not dpt.all(repeats >= 0): |
| 1038 | + raise ValueError("`repeats` elements must be positive") |
| 1039 | + else: |
| 1040 | + raise TypeError( |
| 1041 | + "Expected int, tuple, or `usm_ndarray` for second argument," |
| 1042 | + f"got {type(repeats)}" |
| 1043 | + ) |
| 1044 | + |
| 1045 | + if axis_size == 0: |
| 1046 | + return dpt.empty(x_shape, dtype=x.dtype, sycl_queue=exec_q) |
| 1047 | + |
| 1048 | + if scalar: |
| 1049 | + res_axis_size = repeats * axis_size |
| 1050 | + res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] |
| 1051 | + res = dpt.empty( |
| 1052 | + res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q |
| 1053 | + ) |
| 1054 | + if res_axis_size > 0: |
| 1055 | + ht_rep_ev, _ = ti._repeat_by_scalar( |
| 1056 | + src=x, |
| 1057 | + dst=res, |
| 1058 | + reps=repeats, |
| 1059 | + axis=axis, |
| 1060 | + sycl_queue=exec_q, |
| 1061 | + ) |
| 1062 | + ht_rep_ev.wait() |
| 1063 | + else: |
| 1064 | + if repeats.dtype != dpt.int64: |
| 1065 | + rep_buf = dpt.empty( |
| 1066 | + repeats.shape, |
| 1067 | + dtype=dpt.int64, |
| 1068 | + usm_type=usm_type, |
| 1069 | + sycl_queue=exec_q, |
| 1070 | + ) |
| 1071 | + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
| 1072 | + src=repeats, dst=rep_buf, sycl_queue=exec_q |
| 1073 | + ) |
| 1074 | + cumsum = dpt.empty( |
| 1075 | + (axis_size,), |
| 1076 | + dtype=dpt.int64, |
| 1077 | + usm_type=usm_type, |
| 1078 | + sycl_queue=exec_q, |
| 1079 | + ) |
| 1080 | + # _cumsum_1d synchronizes so `depends` ends here safely |
| 1081 | + res_axis_size = ti._cumsum_1d( |
| 1082 | + rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev] |
| 1083 | + ) |
| 1084 | + res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] |
| 1085 | + res = dpt.empty( |
| 1086 | + res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q |
| 1087 | + ) |
| 1088 | + if res_axis_size > 0: |
| 1089 | + ht_rep_ev, _ = ti._repeat_by_sequence( |
| 1090 | + src=x, |
| 1091 | + dst=res, |
| 1092 | + reps=rep_buf, |
| 1093 | + cumsum=cumsum, |
| 1094 | + axis=axis, |
| 1095 | + sycl_queue=exec_q, |
| 1096 | + ) |
| 1097 | + ht_rep_ev.wait() |
| 1098 | + ht_copy_ev.wait() |
| 1099 | + else: |
| 1100 | + cumsum = dpt.empty( |
| 1101 | + (axis_size,), |
| 1102 | + dtype=dpt.int64, |
| 1103 | + usm_type=usm_type, |
| 1104 | + sycl_queue=exec_q, |
| 1105 | + ) |
| 1106 | + # _cumsum_1d synchronizes so `depends` ends here safely |
| 1107 | + res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q) |
| 1108 | + res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :] |
| 1109 | + res = dpt.empty( |
| 1110 | + res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q |
| 1111 | + ) |
| 1112 | + if res_axis_size > 0: |
| 1113 | + ht_rep_ev, _ = ti._repeat_by_sequence( |
| 1114 | + src=x, |
| 1115 | + dst=res, |
| 1116 | + reps=repeats, |
| 1117 | + cumsum=cumsum, |
| 1118 | + axis=axis, |
| 1119 | + sycl_queue=exec_q, |
| 1120 | + ) |
| 1121 | + ht_rep_ev.wait() |
| 1122 | + return res |
| 1123 | + |
| 1124 | + |
| 1125 | +def tile(x, repetitions): |
| 1126 | + """tile(x, repetitions) |
| 1127 | +
|
| 1128 | + Repeat an input array `x` along each axis a number of times given by |
| 1129 | + `repetitions`. |
| 1130 | +
|
| 1131 | + For `N` = len(`repetitions`) and `M` = len(`x.shape`): |
| 1132 | + - if `M < N`, `x` will have `N - M` new axes prepended to its shape |
| 1133 | + - if `M > N`, `repetitions` will have `M - N` new axes 1 prepended to it |
| 1134 | +
|
| 1135 | + Args: |
| 1136 | + x (usm_ndarray): input array |
| 1137 | +
|
| 1138 | + repetitions (Union[int, Tuple[int, ...]]): |
| 1139 | + The number of repetitions for each dimension. |
| 1140 | +
|
| 1141 | + Returns: |
| 1142 | + usm_narray: |
| 1143 | + Array with tiled elements. |
| 1144 | + The returned array must have the same data type as `x`, |
| 1145 | + is created on the same device as `x` and has the same USM |
| 1146 | + allocation type as `x`. |
| 1147 | + """ |
| 1148 | + if not isinstance(x, dpt.usm_ndarray): |
| 1149 | + raise TypeError(f"Expected usm_ndarray type, got {type(x)}.") |
| 1150 | + |
| 1151 | + if not isinstance(repetitions, tuple): |
| 1152 | + if isinstance(repetitions, int): |
| 1153 | + repetitions = (repetitions,) |
| 1154 | + else: |
| 1155 | + raise TypeError( |
| 1156 | + f"Expected tuple or integer type, got {type(repetitions)}." |
| 1157 | + ) |
| 1158 | + |
| 1159 | + # case of scalar |
| 1160 | + if x.size == 1: |
| 1161 | + if not repetitions: |
| 1162 | + # handle empty tuple |
| 1163 | + repetitions = (1,) |
| 1164 | + return dpt.full( |
| 1165 | + repetitions, |
| 1166 | + x, |
| 1167 | + dtype=x.dtype, |
| 1168 | + usm_type=x.usm_type, |
| 1169 | + sycl_queue=x.sycl_queue, |
| 1170 | + ) |
| 1171 | + rep_dims = len(repetitions) |
| 1172 | + x_dims = x.ndim |
| 1173 | + if rep_dims < x_dims: |
| 1174 | + repetitions = (x_dims - rep_dims) * (1,) + repetitions |
| 1175 | + elif x_dims < rep_dims: |
| 1176 | + x = dpt.reshape(x, (rep_dims - x_dims) * (1,) + x.shape) |
| 1177 | + res_shape = tuple(map(lambda sh, rep: sh * rep, x.shape, repetitions)) |
| 1178 | + # case of empty input |
| 1179 | + if x.size == 0: |
| 1180 | + return dpt.empty( |
| 1181 | + res_shape, x.dtype, usm_type=x.usm_type, sycl_queue=x.sycl_queue |
| 1182 | + ) |
| 1183 | + in_sh = x.shape |
| 1184 | + if res_shape == in_sh: |
| 1185 | + return dpt.copy(x) |
| 1186 | + expanded_sh = [] |
| 1187 | + broadcast_sh = [] |
| 1188 | + out_sz = 1 |
| 1189 | + for i in range(len(res_shape)): |
| 1190 | + out_sz *= res_shape[i] |
| 1191 | + reps, sh = repetitions[i], in_sh[i] |
| 1192 | + if reps == 1: |
| 1193 | + # dimension will be unchanged |
| 1194 | + broadcast_sh.append(sh) |
| 1195 | + expanded_sh.append(sh) |
| 1196 | + elif sh == 1: |
| 1197 | + # dimension will be broadcast |
| 1198 | + broadcast_sh.append(reps) |
| 1199 | + expanded_sh.append(sh) |
| 1200 | + else: |
| 1201 | + broadcast_sh.extend([reps, sh]) |
| 1202 | + expanded_sh.extend([1, sh]) |
| 1203 | + exec_q = x.sycl_queue |
| 1204 | + res = dpt.empty((out_sz,), x.dtype, usm_type=x.usm_type, sycl_queue=exec_q) |
| 1205 | + # no need to copy data for empty output |
| 1206 | + if out_sz > 0: |
| 1207 | + x = dpt.broadcast_to( |
| 1208 | + # this reshape should never copy |
| 1209 | + dpt.reshape(x, expanded_sh), |
| 1210 | + broadcast_sh, |
| 1211 | + ) |
| 1212 | + # copy broadcast input into flat array |
| 1213 | + hev, _ = ti._copy_usm_ndarray_for_reshape( |
| 1214 | + src=x, dst=res, sycl_queue=exec_q |
| 1215 | + ) |
| 1216 | + hev.wait() |
| 1217 | + return dpt.reshape(res, res_shape) |
| 1218 | + |
| 1219 | + |
921 | 1220 | def _supported_dtype(dtypes):
|
922 | 1221 | for dtype in dtypes:
|
923 | 1222 | if dtype.char not in "?bBhHiIlLqQefdFD":
|
|
0 commit comments