Skip to content

Commit 6972c46

Browse files
committed
Implements dpctl.tensor.repeat
1 parent de79b20 commit 6972c46

File tree

7 files changed

+1146
-2
lines changed

7 files changed

+1146
-2
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pybind11_add_module(${python_module_name} MODULE
4949
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
5050
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
5151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sum_reductions.cpp
52+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
5253
)
5354
set(_clang_prefix "")
5455
if (WIN32)

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
iinfo,
7373
moveaxis,
7474
permute_dims,
75+
repeat,
7576
result_type,
7677
roll,
7778
squeeze,
@@ -305,4 +306,5 @@
305306
"tanh",
306307
"trunc",
307308
"allclose",
309+
"repeat",
308310
]

dpctl/tensor/_manipulation_functions.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# limitations under the License.
1616

1717

18-
from itertools import chain, product, repeat
18+
import itertools
19+
import operator
20+
from itertools import chain, product
1921

2022
import numpy as np
23+
from numpy import AxisError
2124
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
2225

2326
import dpctl
@@ -132,7 +135,7 @@ def _broadcast_shape_impl(shapes):
132135
diff = biggest - nds[i]
133136
if diff > 0:
134137
ty = type(shapes[i])
135-
shapes[i] = ty(chain(repeat(1, diff), shapes[i]))
138+
shapes[i] = ty(chain(itertools.repeat(1, diff), shapes[i]))
136139
common_shape = []
137140
for axis in range(biggest):
138141
lengths = [s[axis] for s in shapes]
@@ -928,6 +931,94 @@ def swapaxes(X, axis1, axis2):
928931
return dpt.permute_dims(X, tuple(ind))
929932

930933

934+
def repeat(x, repeats, axis=None):
    """Repeat elements of an array along an axis.

    Args:
        x (usm_ndarray): input array.
        repeats (int or tuple): number of repetitions. An ``int`` is
            applied to every element along ``axis``. A ``tuple`` must have
            exactly as many entries as ``x`` has elements along ``axis``;
            entry ``i`` is the repetition count for element ``i``.
            Counts must not be negative (zero is accepted and produces
            an empty result along ``axis``).
        axis (int, optional): axis along which to repeat. ``None`` is
            only accepted for one-dimensional ``x`` and is treated as
            ``0``. For a zero-dimensional ``x`` the only valid value
            is ``0``.

    Returns:
        usm_ndarray: a newly allocated array with the same data type,
        USM allocation type and queue as ``x``. Its shape equals that of
        ``x`` except along ``axis``. If ``x`` has no elements along
        ``axis``, an empty array with the shape of ``x`` is returned.

    Raises:
        TypeError: if ``x`` is not a ``usm_ndarray``, or ``repeats`` is
            neither an ``int`` nor a ``tuple``.
        ValueError: if ``axis`` is ``None`` and ``x`` is not
            one-dimensional, if a ``repeats`` tuple does not match the
            length of the repeated axis, or if any repetition count is
            negative.
        AxisError: if ``axis`` is out of bounds for ``x``.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")

    x_ndim = x.ndim
    if axis is None:
        if x_ndim != 1:
            raise ValueError(
                f"`axis` cannot be `None` for array of dimension {x_ndim}"
            )
        axis = 0

    x_shape = x.shape
    if x_ndim > 0:
        axis = normalize_axis_index(operator.index(axis), x_ndim)
        axis_size = x_shape[axis]
    else:
        if axis != 0:
            # The exception used to be constructed but never raised, so an
            # invalid axis for a 0-d input was silently accepted.
            raise AxisError("`axis` must be `0` for input of dimension `0`")
        axis_size = x.size

    # `scalar` selects between the uniform-count kernel and the
    # per-element-count kernel further below.
    scalar = False
    if isinstance(repeats, int):
        scalar = True
        if repeats < 0:
            raise ValueError("`repeats` must be a positive integer")
    elif isinstance(repeats, tuple):
        len_reps = len(repeats)
        if len_reps != axis_size:
            raise ValueError(
                "`repeats` tuple must have the same length as the repeated axis"
            )
        elif len_reps == 1:
            # A one-element tuple is equivalent to a scalar count.
            repeats = repeats[0]
            if repeats < 0:
                raise ValueError("`repeats` elements must be positive")
            scalar = True
    else:
        raise TypeError(
            f"Expected int or tuple for second argument, got {type(repeats)}"
        )

    usm_type = x.usm_type
    exec_q = x.sycl_queue

    if axis_size == 0:
        # Nothing to repeat. Pass `usm_type` so that the result is allocated
        # like the input, consistent with every other return path.
        return dpt.empty(
            x_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )

    if scalar:
        res_axis_size = repeats * axis_size
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        # Skip the kernel launch when the result has no elements.
        if res_axis_size > 0:
            hev, _ = ti._repeat_by_scalar(
                src=x,
                dst=res,
                reps=repeats,
                axis=axis,
                sycl_queue=exec_q,
            )
            hev.wait()
    else:
        repeats = dpt.asarray(repeats, dtype="i8", sycl_queue=exec_q)
        if not dpt.all(repeats >= 0):
            raise ValueError("`repeats` elements must be positive")
        cumsum = dpt.empty(
            (axis_size,), dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
        )
        # The value returned by `_cumsum_1d` is used as the length of the
        # output along `axis` (presumably the total of all counts).
        res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
        res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
        res = dpt.empty(
            res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
        )
        # Skip the kernel launch when the result has no elements.
        if res_axis_size > 0:
            hev, _ = ti._repeat_by_sequence(
                src=x,
                dst=res,
                reps=repeats,
                cumsum=cumsum,
                axis=axis,
                sycl_queue=exec_q,
            )
            hev.wait()
    return res
1020+
1021+
9311022
def _supported_dtype(dtypes):
9321023
for dtype in dtypes:
9331024
if dtype.char not in "?bBhHiIlLqQefdFD":

0 commit comments

Comments
 (0)