Skip to content

Commit 1b632b8

Browse files
committed
Implements flat overload for repeat
Adds tests for new functionality
1 parent caa0939 commit 1b632b8

File tree

6 files changed

+480
-82
lines changed

6 files changed

+480
-82
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import operator
2020

2121
import numpy as np
22-
from numpy import AxisError
2322
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
2423

2524
import dpctl
@@ -951,20 +950,11 @@ def repeat(x, repeats, axis=None):
951950
raise TypeError(f"Expected usm_ndarray type, got {type(x)}.")
952951

953952
x_ndim = x.ndim
954-
if axis is None:
955-
if x_ndim > 1:
956-
raise ValueError(
957-
f"`axis` cannot be `None` for array of dimension {x_ndim}"
958-
)
959-
axis = 0
960-
961953
x_shape = x.shape
962-
if x_ndim > 0:
954+
if axis is not None:
963955
axis = normalize_axis_index(operator.index(axis), x_ndim)
964956
axis_size = x_shape[axis]
965957
else:
966-
if axis != 0:
967-
AxisError("`axis` must be `0` for input of dimension `0`")
968958
axis_size = x.size
969959

970960
scalar = False
@@ -1047,7 +1037,10 @@ def repeat(x, repeats, axis=None):
10471037

10481038
if scalar:
10491039
res_axis_size = repeats * axis_size
1050-
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1040+
if axis is not None:
1041+
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1042+
else:
1043+
res_shape = (res_axis_size,)
10511044
res = dpt.empty(
10521045
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
10531046
)
@@ -1081,9 +1074,17 @@ def repeat(x, repeats, axis=None):
10811074
res_axis_size = ti._cumsum_1d(
10821075
rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
10831076
)
1084-
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1077+
if axis is not None:
1078+
res_shape = (
1079+
x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1080+
)
1081+
else:
1082+
res_shape = (res_axis_size,)
10851083
res = dpt.empty(
1086-
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
1084+
res_shape,
1085+
dtype=x.dtype,
1086+
usm_type=usm_type,
1087+
sycl_queue=exec_q,
10871088
)
10881089
if res_axis_size > 0:
10891090
ht_rep_ev, _ = ti._repeat_by_sequence(
@@ -1103,11 +1104,18 @@ def repeat(x, repeats, axis=None):
11031104
usm_type=usm_type,
11041105
sycl_queue=exec_q,
11051106
)
1106-
# _cumsum_1d synchronizes so `depends` ends here safely
11071107
res_axis_size = ti._cumsum_1d(repeats, cumsum, sycl_queue=exec_q)
1108-
res_shape = x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1108+
if axis is not None:
1109+
res_shape = (
1110+
x_shape[:axis] + (res_axis_size,) + x_shape[axis + 1 :]
1111+
)
1112+
else:
1113+
res_shape = (res_axis_size,)
11091114
res = dpt.empty(
1110-
res_shape, dtype=x.dtype, usm_type=usm_type, sycl_queue=exec_q
1115+
res_shape,
1116+
dtype=x.dtype,
1117+
usm_type=usm_type,
1118+
sycl_queue=exec_q,
11111119
)
11121120
if res_axis_size > 0:
11131121
ht_rep_ev, _ = ti._repeat_by_sequence(

dpctl/tensor/libtensor/include/kernels/repeat.hpp

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@ namespace py = pybind11;
4646
using namespace dpctl::tensor::offset_utils;
4747

4848
template <typename OrthogIndexer,
49-
typename AxisIndexer,
49+
typename SrcAxisIndexer,
50+
typename DstAxisIndexer,
5051
typename RepIndexer,
5152
typename T,
5253
typename repT>
5354
class repeat_by_sequence_kernel;
5455

5556
template <typename OrthogIndexer,
56-
typename AxisIndexer,
57+
typename SrcAxisIndexer,
58+
typename DstAxisIndexer,
5759
typename RepIndexer,
5860
typename T,
5961
typename repT>
@@ -66,8 +68,8 @@ class RepeatSequenceFunctor
6668
const repT *cumsum = nullptr;
6769
size_t src_axis_nelems = 1;
6870
OrthogIndexer orthog_strider;
69-
AxisIndexer src_axis_strider;
70-
AxisIndexer dst_axis_strider;
71+
SrcAxisIndexer src_axis_strider;
72+
DstAxisIndexer dst_axis_strider;
7173
RepIndexer reps_strider;
7274

7375
public:
@@ -77,8 +79,8 @@ class RepeatSequenceFunctor
7779
const repT *cumsum_,
7880
size_t src_axis_nelems_,
7981
OrthogIndexer orthog_strider_,
80-
AxisIndexer src_axis_strider_,
81-
AxisIndexer dst_axis_strider_,
82+
SrcAxisIndexer src_axis_strider_,
83+
DstAxisIndexer dst_axis_strider_,
8284
RepIndexer reps_strider_)
8385
: src(src_), dst(dst_), reps(reps_), cumsum(cumsum_),
8486
src_axis_nelems(src_axis_nelems_), orthog_strider(orthog_strider_),
@@ -167,12 +169,12 @@ repeat_by_sequence_impl(sycl::queue &q,
167169

168170
const size_t gws = orthog_nelems * src_axis_nelems;
169171

170-
cgh.parallel_for<repeat_by_sequence_kernel<TwoOffsets_StridedIndexer,
171-
Strided1DIndexer,
172-
Strided1DIndexer, T, repT>>(
172+
cgh.parallel_for<repeat_by_sequence_kernel<
173+
TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer,
174+
Strided1DIndexer, T, repT>>(
173175
sycl::range<1>(gws),
174176
RepeatSequenceFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
175-
Strided1DIndexer, T, repT>(
177+
Strided1DIndexer, Strided1DIndexer, T, repT>(
176178
src_tp, dst_tp, reps_tp, cumsum_tp, src_axis_nelems,
177179
orthog_indexer, src_axis_indexer, dst_axis_indexer,
178180
reps_indexer));
@@ -197,8 +199,8 @@ typedef sycl::event (*repeat_by_sequence_1d_fn_ptr_t)(
197199
char *,
198200
const char *,
199201
const char *,
200-
py::ssize_t,
201-
py::ssize_t,
202+
int,
203+
const py::ssize_t *,
202204
py::ssize_t,
203205
py::ssize_t,
204206
py::ssize_t,
@@ -212,8 +214,8 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
212214
char *dst_cp,
213215
const char *reps_cp,
214216
const char *cumsum_cp,
215-
py::ssize_t src_shape,
216-
py::ssize_t src_stride,
217+
int src_nd,
218+
const py::ssize_t *src_shape_strides,
217219
py::ssize_t dst_shape,
218220
py::ssize_t dst_stride,
219221
py::ssize_t reps_shape,
@@ -231,19 +233,19 @@ sycl::event repeat_by_sequence_1d_impl(sycl::queue &q,
231233
// orthog ndim indexer
232234
TwoZeroOffsets_Indexer orthog_indexer{};
233235
// indexers along repeated axis
234-
Strided1DIndexer src_indexer{0, src_shape, src_stride};
236+
StridedIndexer src_indexer{src_nd, 0, src_shape_strides};
235237
Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};
236238
// indexer along reps array
237239
Strided1DIndexer reps_indexer{0, reps_shape, reps_stride};
238240

239241
const size_t gws = src_nelems;
240242

241-
cgh.parallel_for<
242-
repeat_by_sequence_kernel<TwoZeroOffsets_Indexer, Strided1DIndexer,
243-
Strided1DIndexer, T, repT>>(
243+
cgh.parallel_for<repeat_by_sequence_kernel<
244+
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer,
245+
Strided1DIndexer, T, repT>>(
244246
sycl::range<1>(gws),
245-
RepeatSequenceFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer,
246-
Strided1DIndexer, T, repT>(
247+
RepeatSequenceFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
248+
Strided1DIndexer, Strided1DIndexer, T, repT>(
247249
src_tp, dst_tp, reps_tp, cumsum_tp, src_nelems, orthog_indexer,
248250
src_indexer, dst_indexer, reps_indexer));
249251
});
@@ -260,10 +262,16 @@ template <typename fnT, typename T> struct RepeatSequence1DFactory
260262
}
261263
};
262264

263-
template <typename OrthogIndexer, typename AxisIndexer, typename T>
265+
template <typename OrthogIndexer,
266+
typename SrcAxisIndexer,
267+
typename DstAxisIndexer,
268+
typename T>
264269
class repeat_by_scalar_kernel;
265270

266-
template <typename OrthogIndexer, typename AxisIndexer, typename T>
271+
template <typename OrthogIndexer,
272+
typename SrcAxisIndexer,
273+
typename DstAxisIndexer,
274+
typename T>
267275
class RepeatScalarFunctor
268276
{
269277
private:
@@ -272,17 +280,17 @@ class RepeatScalarFunctor
272280
const py::ssize_t reps = 1;
273281
size_t dst_axis_nelems = 0;
274282
OrthogIndexer orthog_strider;
275-
AxisIndexer src_axis_strider;
276-
AxisIndexer dst_axis_strider;
283+
SrcAxisIndexer src_axis_strider;
284+
DstAxisIndexer dst_axis_strider;
277285

278286
public:
279287
RepeatScalarFunctor(const T *src_,
280288
T *dst_,
281289
const py::ssize_t reps_,
282290
size_t dst_axis_nelems_,
283291
OrthogIndexer orthog_strider_,
284-
AxisIndexer src_axis_strider_,
285-
AxisIndexer dst_axis_strider_)
292+
SrcAxisIndexer src_axis_strider_,
293+
DstAxisIndexer dst_axis_strider_)
286294
: src(src_), dst(dst_), reps(reps_), dst_axis_nelems(dst_axis_nelems_),
287295
orthog_strider(orthog_strider_), src_axis_strider(src_axis_strider_),
288296
dst_axis_strider(dst_axis_strider_)
@@ -354,10 +362,11 @@ sycl::event repeat_by_scalar_impl(sycl::queue &q,
354362

355363
const size_t gws = orthog_nelems * dst_axis_nelems;
356364

357-
cgh.parallel_for<repeat_by_scalar_kernel<TwoOffsets_StridedIndexer,
358-
Strided1DIndexer, T>>(
365+
cgh.parallel_for<repeat_by_scalar_kernel<
366+
TwoOffsets_StridedIndexer, Strided1DIndexer, Strided1DIndexer, T>>(
359367
sycl::range<1>(gws),
360-
RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer, T>(
368+
RepeatScalarFunctor<TwoOffsets_StridedIndexer, Strided1DIndexer,
369+
Strided1DIndexer, T>(
361370
src_tp, dst_tp, reps, dst_axis_nelems, orthog_indexer,
362371
src_axis_indexer, dst_axis_indexer));
363372
});
@@ -380,8 +389,8 @@ typedef sycl::event (*repeat_by_scalar_1d_fn_ptr_t)(
380389
const char *,
381390
char *,
382391
const py::ssize_t,
383-
py::ssize_t,
384-
py::ssize_t,
392+
int,
393+
const py::ssize_t *,
385394
py::ssize_t,
386395
py::ssize_t,
387396
const std::vector<sycl::event> &);
@@ -392,8 +401,8 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
392401
const char *src_cp,
393402
char *dst_cp,
394403
const py::ssize_t reps,
395-
py::ssize_t src_shape,
396-
py::ssize_t src_stride,
404+
int src_nd,
405+
const py::ssize_t *src_shape_strides,
397406
py::ssize_t dst_shape,
398407
py::ssize_t dst_stride,
399408
const std::vector<sycl::event> &depends)
@@ -407,17 +416,18 @@ sycl::event repeat_by_scalar_1d_impl(sycl::queue &q,
407416
// orthog ndim indexer
408417
TwoZeroOffsets_Indexer orthog_indexer{};
409418
// indexers along repeated axis
410-
Strided1DIndexer src_indexer(0, src_shape, src_stride);
419+
StridedIndexer src_indexer(src_nd, 0, src_shape_strides);
411420
Strided1DIndexer dst_indexer{0, dst_shape, dst_stride};
412421

413422
const size_t gws = dst_nelems;
414423

415-
cgh.parallel_for<repeat_by_scalar_kernel<TwoZeroOffsets_Indexer,
416-
Strided1DIndexer, T>>(
424+
cgh.parallel_for<repeat_by_scalar_kernel<
425+
TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, T>>(
417426
sycl::range<1>(gws),
418-
RepeatScalarFunctor<TwoZeroOffsets_Indexer, Strided1DIndexer, T>(
419-
src_tp, dst_tp, reps, dst_nelems, orthog_indexer, src_indexer,
420-
dst_indexer));
427+
RepeatScalarFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
428+
Strided1DIndexer, T>(src_tp, dst_tp, reps,
429+
dst_nelems, orthog_indexer,
430+
src_indexer, dst_indexer));
421431
});
422432

423433
return repeat_ev;

0 commit comments

Comments
 (0)