Skip to content

Commit 83fff33

Browse files
Implements dpctl.tensor.repeat, dpctl.tensor.tile (#1381)
* Factored mask_positions implementation and kernels into separate files Doing this will make implementing more accumulators convenient * Moved split_iteration_space into simplify_iteration_space.cpp * Implements _cumsum_1d * Implements dpctl.tensor.repeat * Adds tests for dpctl.tensor.repeat * Added repeat.__docstring__ * repeat now permits 0d arrays when `axis` is `None` - Also adds a check that the sole element of a length 1 tuple is an integer before proceeding to the scalar case * Implemented usm_ndarray `repeats` inputs for repeat * Implements dpctl.tensor.tile --------- Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 51d994a commit 83fff33

16 files changed

+2571
-529
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ endif()
3333
set(python_module_name _tensor_impl)
3434
pybind11_add_module(${python_module_name} MODULE
3535
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
36+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
3637
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
3738
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
3839
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
@@ -49,6 +50,7 @@ pybind11_add_module(${python_module_name} MODULE
4950
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
5051
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
5152
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
53+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
5254
)
5355
set(_clang_prefix "")
5456
if (WIN32)

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@
7272
iinfo,
7373
moveaxis,
7474
permute_dims,
75+
repeat,
7576
result_type,
7677
roll,
7778
squeeze,
7879
stack,
7980
swapaxes,
81+
tile,
8082
unstack,
8183
)
8284
from dpctl.tensor._print import (
@@ -305,4 +307,6 @@
305307
"tanh",
306308
"trunc",
307309
"allclose",
310+
"repeat",
311+
"tile",
308312
]

dpctl/tensor/_manipulation_functions.py

Lines changed: 301 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
# limitations under the License.
1616

1717

18+
import itertools
1819
import operator
19-
from itertools import chain, repeat
2020

2121
import numpy as np
22+
from numpy import AxisError
2223
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
2324

2425
import dpctl
@@ -133,7 +134,9 @@ def _broadcast_shape_impl(shapes):
133134
diff = biggest - nds[i]
134135
if diff > 0:
135136
ty = type(shapes[i])
136-
shapes[i] = ty(chain(repeat(1, diff), shapes[i]))
137+
shapes[i] = ty(
138+
itertools.chain(itertools.repeat(1, diff), shapes[i])
139+
)
137140
common_shape = []
138141
for axis in range(biggest):
139142
lengths = [s[axis] for s in shapes]
@@ -918,6 +921,302 @@ def swapaxes(X, axis1, axis2):
918921
return dpt.permute_dims(X, tuple(ind))
919922

920923

924+
def repeat(x, repeats, axis=None):
    """repeat(x, repeats, axis=None)

    Repeat elements of an array.

    Args:
        x (usm_ndarray): input array

        repeats (Union[int, Tuple[int, ...], usm_ndarray]):
            The number of repetitions for each element. Values must be
            non-negative (zero is accepted and removes the element).
            `repeats` is broadcasted to fit the shape of the given axis:
            an integer, a tuple of length one, or an array holding a single
            element applies to every element along `axis`; otherwise the
            number of entries must equal the size of `axis`.

        axis (Optional[int]):
            The axis along which to repeat values. The `axis` is required
            if input array has more than one dimension.

    Returns:
        usm_ndarray:
            Array with repeated elements.
            The returned array has the same data type as `x`. When
            `repeats` is an integer or a tuple, it is allocated on the
            queue of `x` with the USM allocation type of `x`. When
            `repeats` is a `usm_ndarray`, the execution queue and the USM
            allocation type are the ones derived from both `x` and
            `repeats`.

    Raises:
        AxisError: if `axis` value is invalid.
        TypeError: if `x` is not a `usm_ndarray`, if `repeats` is not an
            int, tuple, or `usm_ndarray`, or if the data type of a
            `repeats` array cannot be cast to `int64`.
        ValueError: if `axis` is `None` for an input with more than one
            dimension, if `repeats` contains negative values, has more
            than one dimension, or cannot be broadcast to the size of the
            repeated axis.
        ExecutionPlacementError: if `x` and a `repeats` array do not
            share a common execution queue.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    if axis is None:
        if x_ndim > 1:
            raise ValueError(
                f"`axis` cannot be `None` for array of dimension {x_ndim}"
            )
        axis = 0

    x_shape = x.shape
    if x_ndim > 0:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        # a 0d input is treated as a single element repeated along axis 0
        if axis != 0:
            raise AxisError("`axis` must be `0` for input of dimension `0`")
        axis_size = x.size

    # `scalar` records that a single repetition count applies to every
    # element; after validation `repeats` is then a Python integer
    scalar = False
    if isinstance(repeats, int):
        if repeats < 0:
            raise ValueError("`repeats` must be a positive integer")
        usm_type = x.usm_type
        exec_q = x.sycl_queue
        scalar = True
    elif isinstance(repeats, dpt.usm_ndarray):
        if repeats.ndim > 1:
            raise ValueError(
                "`repeats` array must be 0- or 1-dimensional, got "
                f"{repeats.ndim}"
            )
        exec_q = dpctl.utils.get_execution_queue(
            (x.sycl_queue, repeats.sycl_queue)
        )
        if exec_q is None:
            raise dputils.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dpctl.utils.get_coerced_usm_type(
            (
                x.usm_type,
                repeats.usm_type,
            )
        )
        dpctl.utils.validate_usm_type(usm_type, allow_none=False)
        if not dpt.can_cast(repeats.dtype, dpt.int64, casting="same_kind"):
            raise TypeError(
                f"`repeats` data type `{repeats.dtype}` cannot be cast to "
                "`int64` according to the casting rule ''same_kind.''"
            )
        if repeats.size == 1:
            scalar = True
            # bring the single element to the host
            repeats = int(repeats)
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
        else:
            if repeats.size != axis_size:
                raise ValueError(
                    "`repeats` array must be broadcastable to the size of "
                    "the repeated axis"
                )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be positive")

    elif isinstance(repeats, tuple):
        usm_type = x.usm_type
        exec_q = x.sycl_queue

        len_reps = len(repeats)
        if len_reps == 1:
            # a one-element tuple broadcasts to an axis of any size, so it
            # must be recognized before the length comparison below
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
            scalar = True
        elif len_reps != axis_size:
            raise ValueError(
                "`repeats` tuple must have the same length as the repeated "
                "axis"
            )
        else:
            repeats = dpt.asarray(
                repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
            )
            if not dpt.all(repeats >= 0):
                raise ValueError("`repeats` elements must be positive")
    else:
        raise TypeError(
            "Expected int, tuple, or `usm_ndarray` for second argument, "
            f"got {type(repeats)}"
        )

    if axis_size == 0:
        # nothing to repeat: the result has the shape of the input
        return dpt.empty(
            x_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )

    if scalar:
        res_axis_size = repeats * axis_size
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        # no kernel needs to be submitted for an empty result
        if res_axis_size > 0:
            ht_rep_ev, _ = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
            )
            ht_rep_ev.wait()
    else:
        if repeats.dtype != dpt.int64:
            # copy the repetition counts into an int64 buffer before they
            # are handed to the cumulative-sum and repeat routines
            rep_buf = dpt.empty(
                repeats.shape,
                dtype=dpt.int64,
                usm_type=usm_type,
                sycl_queue=exec_q,
            )
            ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=repeats, dst=rep_buf, sycl_queue=exec_q
            )
            cumsum = dpt.empty(
                (axis_size,),
                dtype=dpt.int64,
                usm_type=usm_type,
                sycl_queue=exec_q,
            )
            # _cumsum_1d synchronizes so `depends` ends here safely
            res_axis_size = ti._cumsum_1d(
                rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
            )
            res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
            res = dpt.empty(
                res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
            )
            if res_axis_size > 0:
                ht_rep_ev, _ = ti._repeat_by_sequence(
                    src=x,
                    dst=res,
                    reps=rep_buf,
                    cumsum=cumsum,
                    axis=axis,
                    sycl_queue=exec_q,
                )
                ht_rep_ev.wait()
            ht_copy_ev.wait()
        else:
            cumsum = dpt.empty(
                (axis_size,),
                dtype=dpt.int64,
                usm_type=usm_type,
                sycl_queue=exec_q,
            )
            # the value returned by _cumsum_1d is used as the size of the
            # repeated axis in the result
            res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
            res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
            res = dpt.empty(
                res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
            )
            if res_axis_size > 0:
                ht_rep_ev, _ = ti._repeat_by_sequence(
                    src=x,
                    dst=res,
                    reps=repeats,
                    cumsum=cumsum,
                    axis=axis,
                    sycl_queue=exec_q,
                )
                ht_rep_ev.wait()
    return res
1123+
1124+
1125+
def tile(x, repetitions):
    """tile(x, repetitions)

    Repeat an input array `x` along each axis a number of times given by
    `repetitions`.

    For `N` = len(`repetitions`) and `M` = len(`x.shape`):
    - if `M < N`, `x` will have `N - M` new axes prepended to its shape
    - if `M > N`, `repetitions` will have `M - N` new axes 1 prepended to it

    The result therefore always has `max(M, N)` dimensions, including when
    `x` holds a single element.

    Args:
        x (usm_ndarray): input array

        repetitions (Union[int, Tuple[int, ...]]):
            The number of repetitions for each dimension.

    Returns:
        usm_ndarray:
            Array with tiled elements.
            The returned array has the same data type as `x`,
            is created on the same device as `x` and has the same USM
            allocation type as `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray` or `repetitions` is
            neither an integer nor a tuple.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    if not isinstance(repetitions, tuple):
        if isinstance(repetitions, int):
            repetitions = (repetitions,)
        else:
            raise TypeError(
                f"Expected tuple or integer type, got {type(repetitions)}."
            )

    # bring `x` and `repetitions` to a common rank before computing any
    # shape, so that every input (single-element ones included) yields a
    # result of rank max(x.ndim, len(repetitions))
    rep_dims = len(repetitions)
    x_dims = x.ndim
    if rep_dims < x_dims:
        repetitions = (x_dims - rep_dims) * (1,) + repetitions
    elif x_dims < rep_dims:
        x = dpt.reshape(x, (rep_dims - x_dims) * (1,) + x.shape)
    in_sh = x.shape
    res_shape = tuple(sh * rep for sh, rep in zip(in_sh, repetitions))
    # case of empty input
    if x.size == 0:
        return dpt.empty(
            res_shape,
            dtype=x.dtype,
            usm_type=x.usm_type,
            sycl_queue=x.sycl_queue,
        )
    # every repetition is 1: the result is a copy of the input
    if res_shape == in_sh:
        return dpt.copy(x)
    # Build two shapes of equal length:
    #   `expanded_sh`  - the input shape with unit axes inserted, and
    #   `broadcast_sh` - the shape the expanded input is broadcast to.
    # Copying the broadcast view into a flat buffer and reshaping it to
    # `res_shape` produces the tiled result.
    expanded_sh = []
    broadcast_sh = []
    out_sz = 1
    for reps, sh in zip(repetitions, in_sh):
        out_sz *= sh * reps
        if reps == 1:
            # dimension will be unchanged
            broadcast_sh.append(sh)
            expanded_sh.append(sh)
        elif sh == 1:
            # dimension will be broadcast
            broadcast_sh.append(reps)
            expanded_sh.append(sh)
        else:
            # a new leading unit axis is broadcast to `reps`
            broadcast_sh.extend([reps, sh])
            expanded_sh.extend([1, sh])
    exec_q = x.sycl_queue
    res = dpt.empty(
        (out_sz,), dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
    )
    # no need to copy data for empty output
    if out_sz > 0:
        x = dpt.broadcast_to(
            # this reshape should never copy
            dpt.reshape(x, expanded_sh),
            broadcast_sh,
        )
        # copy broadcast input into flat array
        hev, _ = ti._copy_usm_ndarray_for_reshape(
            src=x, dst=res, sycl_queue=exec_q
        )
        hev.wait()
    return dpt.reshape(res, res_shape)
1218+
1219+
9211220
def _supported_dtype(dtypes):
9221221
for dtype in dtypes:
9231222
if dtype.char not in "?bBhHiIlLqQefdFD":

0 commit comments

Comments
 (0)