Skip to content

Commit de79b20

Browse files
committed
Implements _cumsum_1d
1 parent 68a652e commit de79b20

File tree

4 files changed

+226
-36
lines changed

4 files changed

+226
-36
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -213,29 +213,27 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
213213
return out_event;
214214
}
215215

216-
// mask positions
217-
218-
typedef size_t (*mask_positions_contig_impl_fn_ptr_t)(
216+
// Pointer to an accumulation routine over contiguous input. The parameter
// list mirrors accumulate_contig_impl: execution queue, number of elements,
// type-erased input bytes, type-erased output bytes, and the events the
// computation must wait on. The routine returns a size_t.
typedef size_t (*accumulate_contig_impl_fn_ptr_t)(
219217
sycl::queue,
220218
size_t,
221219
const char *,
222220
char *,
223221
std::vector<sycl::event> const &);
224222

225-
template <typename maskT, typename cumsumT>
226-
size_t mask_positions_contig_impl(sycl::queue q,
227-
size_t n_elems,
228-
const char *mask,
229-
char *cumsum,
230-
std::vector<sycl::event> const &depends = {})
223+
template <typename maskT, typename cumsumT, typename transformerT>
224+
size_t accumulate_contig_impl(sycl::queue q,
225+
size_t n_elems,
226+
const char *mask,
227+
char *cumsum,
228+
std::vector<sycl::event> const &depends = {})
231229
{
232230
constexpr int n_wi = 8;
233231
const maskT *mask_data_ptr = reinterpret_cast<const maskT *>(mask);
234232
cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
235233
size_t wg_size = 128;
236234

237235
NoOpIndexer flat_indexer{};
238-
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
236+
transformerT non_zero_indicator{};
239237

240238
sycl::event comp_ev =
241239
inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(flat_indexer),
@@ -263,7 +261,9 @@ template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt32
263261
{
264262
fnT get()
265263
{
266-
fnT fn = mask_positions_contig_impl<T, std::int32_t>;
264+
using cumsumT = std::int32_t;
265+
fnT fn =
266+
accumulate_contig_impl<T, cumsumT, NonZeroIndicator<T, cumsumT>>;
267267
return fn;
268268
}
269269
};
@@ -272,12 +272,30 @@ template <typename fnT, typename T> struct MaskPositionsContigFactoryForInt64
272272
{
273273
fnT get()
274274
{
275-
fnT fn = mask_positions_contig_impl<T, std::int64_t>;
275+
using cumsumT = std::int64_t;
276+
fnT fn =
277+
accumulate_contig_impl<T, cumsumT, NonZeroIndicator<T, cumsumT>>;
276278
return fn;
277279
}
278280
};
279281

280-
typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
282+
template <typename fnT, typename T> struct Cumsum1DContigFactory
283+
{
284+
fnT get()
285+
{
286+
if constexpr (std::is_integral_v<T>) {
287+
using cumsumT = std::int64_t;
288+
fnT fn =
289+
accumulate_contig_impl<T, cumsumT, NoOpTransformer<cumsumT>>;
290+
return fn;
291+
}
292+
else {
293+
return nullptr;
294+
}
295+
}
296+
};
297+
298+
typedef size_t (*accumulate_strided_impl_fn_ptr_t)(
281299
sycl::queue,
282300
size_t,
283301
const char *,
@@ -286,22 +304,22 @@ typedef size_t (*mask_positions_strided_impl_fn_ptr_t)(
286304
char *,
287305
std::vector<sycl::event> const &);
288306

289-
template <typename maskT, typename cumsumT>
290-
size_t mask_positions_strided_impl(sycl::queue q,
291-
size_t n_elems,
292-
const char *mask,
293-
int nd,
294-
const py::ssize_t *shape_strides,
295-
char *cumsum,
296-
std::vector<sycl::event> const &depends = {})
307+
template <typename maskT, typename cumsumT, typename transformerT>
308+
size_t accumulate_strided_impl(sycl::queue q,
309+
size_t n_elems,
310+
const char *mask,
311+
int nd,
312+
const py::ssize_t *shape_strides,
313+
char *cumsum,
314+
std::vector<sycl::event> const &depends = {})
297315
{
298316
constexpr int n_wi = 8;
299317
const maskT *mask_data_ptr = reinterpret_cast<const maskT *>(mask);
300318
cumsumT *cumsum_data_ptr = reinterpret_cast<cumsumT *>(cumsum);
301319
size_t wg_size = 128;
302320

303321
StridedIndexer strided_indexer{nd, 0, shape_strides};
304-
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
322+
transformerT non_zero_indicator{};
305323

306324
sycl::event comp_ev =
307325
inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(strided_indexer),
@@ -329,7 +347,9 @@ template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt32
329347
{
330348
fnT get()
331349
{
332-
fnT fn = mask_positions_strided_impl<T, std::int32_t>;
350+
using cumsumT = std::int32_t;
351+
fnT fn =
352+
accumulate_strided_impl<T, cumsumT, NonZeroIndicator<T, cumsumT>>;
333353
return fn;
334354
}
335355
};
@@ -338,11 +358,29 @@ template <typename fnT, typename T> struct MaskPositionsStridedFactoryForInt64
338358
{
339359
fnT get()
340360
{
341-
fnT fn = mask_positions_strided_impl<T, std::int64_t>;
361+
using cumsumT = std::int64_t;
362+
fnT fn =
363+
accumulate_strided_impl<T, cumsumT, NonZeroIndicator<T, cumsumT>>;
342364
return fn;
343365
}
344366
};
345367

368+
template <typename fnT, typename T> struct Cumsum1DStridedFactory
369+
{
370+
fnT get()
371+
{
372+
if constexpr (std::is_integral_v<T>) {
373+
using cumsumT = std::int64_t;
374+
fnT fn =
375+
accumulate_strided_impl<T, cumsumT, NoOpTransformer<cumsumT>>;
376+
return fn;
377+
}
378+
else {
379+
return nullptr;
380+
}
381+
}
382+
};
383+
346384
} // namespace accumulators
347385
} // namespace kernels
348386
} // namespace tensor

dpctl/tensor/libtensor/source/accumulators.cpp

Lines changed: 148 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,48 +48,47 @@ namespace py_internal
4848

4949
namespace td_ns = dpctl::tensor::type_dispatch;
5050

51-
using dpctl::tensor::kernels::accumulators::mask_positions_contig_impl_fn_ptr_t;
52-
static mask_positions_contig_impl_fn_ptr_t
51+
using dpctl::tensor::kernels::accumulators::accumulate_contig_impl_fn_ptr_t;
52+
static accumulate_contig_impl_fn_ptr_t
5353
mask_positions_contig_i64_dispatch_vector[td_ns::num_types];
54-
static mask_positions_contig_impl_fn_ptr_t
54+
static accumulate_contig_impl_fn_ptr_t
5555
mask_positions_contig_i32_dispatch_vector[td_ns::num_types];
5656

57-
using dpctl::tensor::kernels::accumulators::
58-
mask_positions_strided_impl_fn_ptr_t;
59-
static mask_positions_strided_impl_fn_ptr_t
57+
using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
58+
static accumulate_strided_impl_fn_ptr_t
6059
mask_positions_strided_i64_dispatch_vector[td_ns::num_types];
61-
static mask_positions_strided_impl_fn_ptr_t
60+
static accumulate_strided_impl_fn_ptr_t
6261
mask_positions_strided_i32_dispatch_vector[td_ns::num_types];
6362

6463
void populate_mask_positions_dispatch_vectors(void)
6564
{
6665
using dpctl::tensor::kernels::accumulators::
6766
MaskPositionsContigFactoryForInt64;
68-
td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
67+
td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
6968
MaskPositionsContigFactoryForInt64,
7069
td_ns::num_types>
7170
dvb1;
7271
dvb1.populate_dispatch_vector(mask_positions_contig_i64_dispatch_vector);
7372

7473
using dpctl::tensor::kernels::accumulators::
7574
MaskPositionsContigFactoryForInt32;
76-
td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
75+
td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
7776
MaskPositionsContigFactoryForInt32,
7877
td_ns::num_types>
7978
dvb2;
8079
dvb2.populate_dispatch_vector(mask_positions_contig_i32_dispatch_vector);
8180

8281
using dpctl::tensor::kernels::accumulators::
8382
MaskPositionsStridedFactoryForInt64;
84-
td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
83+
td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
8584
MaskPositionsStridedFactoryForInt64,
8685
td_ns::num_types>
8786
dvb3;
8887
dvb3.populate_dispatch_vector(mask_positions_strided_i64_dispatch_vector);
8988

9089
using dpctl::tensor::kernels::accumulators::
9190
MaskPositionsStridedFactoryForInt32;
92-
td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
91+
td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
9392
MaskPositionsStridedFactoryForInt32,
9493
td_ns::num_types>
9594
dvb4;
@@ -210,6 +209,144 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
210209
return total_set;
211210
}
212211

212+
using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
213+
static accumulate_strided_impl_fn_ptr_t
214+
cumsum_1d_strided_dispatch_vector[td_ns::num_types];
215+
using dpctl::tensor::kernels::accumulators::accumulate_contig_impl_fn_ptr_t;
216+
static accumulate_contig_impl_fn_ptr_t
217+
cumsum_1d_contig_dispatch_vector[td_ns::num_types];
218+
219+
void populate_cumsum_1d_dispatch_vectors(void)
220+
{
221+
using dpctl::tensor::kernels::accumulators::Cumsum1DContigFactory;
222+
td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
223+
Cumsum1DContigFactory, td_ns::num_types>
224+
dvb1;
225+
dvb1.populate_dispatch_vector(cumsum_1d_contig_dispatch_vector);
226+
227+
using dpctl::tensor::kernels::accumulators::Cumsum1DStridedFactory;
228+
td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
229+
Cumsum1DStridedFactory, td_ns::num_types>
230+
dvb2;
231+
dvb2.populate_dispatch_vector(cumsum_1d_strided_dispatch_vector);
232+
233+
return;
234+
}
235+
236+
/*! @brief Accumulates the elements of `src` into the one-dimensional array
 *  `cumsum` by dispatching on the data type of `src`.
 *
 *  The implementations are instantiated with a pass-through transformer and
 *  an int64 accumulator, i.e. they are expected to produce a running sum of
 *  the source values.
 *
 *  @param src     Source array; only integral data types have an
 *                 implementation in the dispatch vectors.
 *  @param cumsum  Destination array; must be 1D, C-contiguous, of int64
 *                 data type, with as many elements as `src`.
 *  @param exec_q  Queue to submit work to; must be compatible with the
 *                 allocation queues of both arrays.
 *  @param depends Events the computation must wait on.
 *
 *  @return 0 when `src` is empty, otherwise the value returned by the
 *          dispatched implementation (presumably the total sum, i.e. the
 *          last element written to `cumsum` -- TODO confirm against the
 *          kernel implementation).
 *
 *  @throws py::value_error    if `cumsum` is not 1D, not C-contiguous, has a
 *                             different number of elements than `src`, is
 *                             not int64, or the queue is incompatible.
 *  @throws std::runtime_error if `src` has a data type without an
 *                             implementation, or packing of shape and
 *                             strides for the strided path fails.
 */
size_t py_cumsum_1d(dpctl::tensor::usm_ndarray src,
                    dpctl::tensor::usm_ndarray cumsum,
                    sycl::queue exec_q,
                    std::vector<sycl::event> const &depends)
{
    // cumsum is 1D
    if (cumsum.get_ndim() != 1) {
        throw py::value_error("cumsum array must be one-dimensional.");
    }

    if (!cumsum.is_c_contiguous()) {
        throw py::value_error("Expecting `cumsum` array to be C-contiguous.");
    }

    // cumsum.shape == (src.size,)
    auto src_size = src.get_size();
    auto cumsum_size = cumsum.get_shape(0);
    if (cumsum_size != src_size) {
        throw py::value_error("Inconsistent dimensions");
    }

    if (!dpctl::utils::queues_are_compatible(exec_q, {src, cumsum})) {
        // FIXME: use ExecutionPlacementError
        throw py::value_error(
            "Execution queue is not compatible with allocation queues");
    }

    // nothing to accumulate
    if (src_size == 0) {
        return 0;
    }

    int src_typenum = src.get_typenum();
    int cumsum_typenum = cumsum.get_typenum();

    // src must have an integral data type: dispatch vector entries for all
    // other types are nullptr and are rejected below
    const char *src_data = src.get_data();
    char *cumsum_data = cumsum.get_data();

    auto const &array_types = td_ns::usm_ndarray_types();

    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
    int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);

    // this cumsum must be int64_t only
    constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
    if (cumsum_typeid != int64_typeid) {
        throw py::value_error(
            "Cumulative sum array must have int64 data-type.");
    }

    // Contiguous implementation: no shape/strides metadata is needed
    if (src.is_c_contiguous()) {
        auto fn = cumsum_1d_contig_dispatch_vector[src_typeid];
        if (fn == nullptr) {
            throw std::runtime_error(
                "this cumsum requires integer type, got src_typeid=" +
                std::to_string(src_typeid));
        }
        return fn(exec_q, src_size, src_data, cumsum_data, depends);
    }

    // Strided implementation
    auto strided_fn = cumsum_1d_strided_dispatch_vector[src_typeid];
    if (strided_fn == nullptr) {
        // reject unsupported types before doing any host-side work
        throw std::runtime_error(
            "this cumsum requires integer type, got src_typeid=" +
            std::to_string(src_typeid));
    }

    const py::ssize_t *shape = src.get_shape_raw();
    auto const &strides_vector = src.get_strides_vector();

    using shT = std::vector<py::ssize_t>;
    shT compact_shape;
    shT compact_strides;

    // nd is handed to compact_iteration_space and afterwards describes the
    // compacted iteration space (it is compared against the packed size)
    int nd = src.get_ndim();

    dpctl::tensor::py_internal::compact_iteration_space(
        nd, shape, strides_vector, compact_shape, compact_strides);

    std::vector<sycl::event> host_task_events;

    using dpctl::tensor::offset_utils::device_allocate_and_pack;
    const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
        exec_q, host_task_events, compact_shape, compact_strides);
    py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
    if (shape_strides == nullptr) {
        sycl::event::wait(host_task_events);
        throw std::runtime_error("Unexpected error");
    }
    sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple);

    // packed buffer is expected to hold nd shape entries and nd strides
    if (2 * static_cast<size_t>(nd) != std::get<1>(ptr_size_event_tuple)) {
        copy_shape_ev.wait();
        sycl::event::wait(host_task_events);
        sycl::free(shape_strides, exec_q);
        throw std::runtime_error("Unexpected error");
    }

    // the kernel must not start before shape/strides reached the device
    std::vector<sycl::event> dependent_events;
    dependent_events.reserve(depends.size() + 1);
    dependent_events.push_back(copy_shape_ev);
    dependent_events.insert(dependent_events.end(), depends.begin(),
                            depends.end());

    size_t total = 0;
    try {
        total = strided_fn(exec_q, src_size, src_data, nd, shape_strides,
                           cumsum_data, dependent_events);
    } catch (...) {
        // Do not leak the packed shape/strides allocation when the
        // implementation throws. Drain the queue first so that no copy or
        // kernel that may still reference the allocation is in flight.
        exec_q.wait();
        sycl::event::wait(host_task_events);
        sycl::free(shape_strides, exec_q);
        throw;
    }

    sycl::event::wait(host_task_events);
    sycl::free(shape_strides, exec_q);

    return total;
}
349+
213350
} // namespace py_internal
214351
} // namespace tensor
215352
} // namespace dpctl

dpctl/tensor/libtensor/source/accumulators.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ extern size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
4444
sycl::queue exec_q,
4545
std::vector<sycl::event> const &depends = {});
4646

47+
// Fills the per-type dispatch vectors (contiguous and strided) that
// py_cumsum_1d indexes by the source array's type id.
extern void populate_cumsum_1d_dispatch_vectors(void);
48+
49+
// Accumulates elements of `src` into `cumsum` on `exec_q` after `depends`
// complete. `cumsum` must be one-dimensional, C-contiguous, of int64 data
// type and hold as many elements as `src`; otherwise py::value_error is
// thrown. Source types without a dispatch entry raise std::runtime_error.
// Returns 0 for an empty `src`, otherwise the value returned by the
// dispatched implementation.
extern size_t py_cumsum_1d(dpctl::tensor::usm_ndarray src,
50+
dpctl::tensor::usm_ndarray cumsum,
51+
sycl::queue exec_q,
52+
std::vector<sycl::event> const &depends = {});
53+
4754
} // namespace py_internal
4855
} // namespace tensor
4956
} // namespace dpctl

0 commit comments

Comments
 (0)