@@ -48,48 +48,47 @@ namespace py_internal

namespace td_ns = dpctl::tensor::type_dispatch;

-using dpctl::tensor::kernels::accumulators::mask_positions_contig_impl_fn_ptr_t;
-static mask_positions_contig_impl_fn_ptr_t
+using dpctl::tensor::kernels::accumulators::accumulate_contig_impl_fn_ptr_t;
+static accumulate_contig_impl_fn_ptr_t
     mask_positions_contig_i64_dispatch_vector[td_ns::num_types];
-static mask_positions_contig_impl_fn_ptr_t
+static accumulate_contig_impl_fn_ptr_t
     mask_positions_contig_i32_dispatch_vector[td_ns::num_types];

-using dpctl::tensor::kernels::accumulators::
-    mask_positions_strided_impl_fn_ptr_t;
-static mask_positions_strided_impl_fn_ptr_t
+using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
+static accumulate_strided_impl_fn_ptr_t
     mask_positions_strided_i64_dispatch_vector[td_ns::num_types];
-static mask_positions_strided_impl_fn_ptr_t
+static accumulate_strided_impl_fn_ptr_t
     mask_positions_strided_i32_dispatch_vector[td_ns::num_types];

void populate_mask_positions_dispatch_vectors(void)
{
    using dpctl::tensor::kernels::accumulators::
        MaskPositionsContigFactoryForInt64;
-    td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
+    td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
                                  MaskPositionsContigFactoryForInt64,
                                  td_ns::num_types>
        dvb1;
    dvb1.populate_dispatch_vector(mask_positions_contig_i64_dispatch_vector);

    using dpctl::tensor::kernels::accumulators::
        MaskPositionsContigFactoryForInt32;
-    td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
+    td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
                                  MaskPositionsContigFactoryForInt32,
                                  td_ns::num_types>
        dvb2;
    dvb2.populate_dispatch_vector(mask_positions_contig_i32_dispatch_vector);

    using dpctl::tensor::kernels::accumulators::
        MaskPositionsStridedFactoryForInt64;
-    td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
+    td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
                                  MaskPositionsStridedFactoryForInt64,
                                  td_ns::num_types>
        dvb3;
    dvb3.populate_dispatch_vector(mask_positions_strided_i64_dispatch_vector);

    using dpctl::tensor::kernels::accumulators::
        MaskPositionsStridedFactoryForInt32;
-    td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
+    td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
                                  MaskPositionsStridedFactoryForInt32,
                                  td_ns::num_types>
        dvb4;
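
For reference, the renamed function-pointer typedefs shared by the mask_positions tables and the new cumsum entry point have signatures along these lines. This is a sketch inferred from the call sites in py_cumsum_1d below; the authoritative definitions live in the kernels' accumulators header:

// Sketch only: signatures inferred from how the pointers are invoked below.
// Contiguous case: (queue, element count, src data, cumsum data, depends);
// returns the final accumulated value.
typedef size_t (*accumulate_contig_impl_fn_ptr_t)(
    sycl::queue,
    size_t,
    const char *,
    char *,
    std::vector<sycl::event> const &);

// Strided case additionally takes the dimensionality and a packed
// shape-and-strides device buffer.
typedef size_t (*accumulate_strided_impl_fn_ptr_t)(
    sycl::queue,
    size_t,
    const char *,
    int,
    const py::ssize_t *,
    char *,
    std::vector<sycl::event> const &);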
@@ -210,6 +209,144 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
    return total_set;
}

+using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
+static accumulate_strided_impl_fn_ptr_t
+    cumsum_1d_strided_dispatch_vector[td_ns::num_types];
+using dpctl::tensor::kernels::accumulators::accumulate_contig_impl_fn_ptr_t;
+static accumulate_contig_impl_fn_ptr_t
+    cumsum_1d_contig_dispatch_vector[td_ns::num_types];
+
+void populate_cumsum_1d_dispatch_vectors(void)
+{
+    using dpctl::tensor::kernels::accumulators::Cumsum1DContigFactory;
+    td_ns::DispatchVectorBuilder<accumulate_contig_impl_fn_ptr_t,
+                                 Cumsum1DContigFactory, td_ns::num_types>
+        dvb1;
+    dvb1.populate_dispatch_vector(cumsum_1d_contig_dispatch_vector);
+
+    using dpctl::tensor::kernels::accumulators::Cumsum1DStridedFactory;
+    td_ns::DispatchVectorBuilder<accumulate_strided_impl_fn_ptr_t,
+                                 Cumsum1DStridedFactory, td_ns::num_types>
+        dvb2;
+    dvb2.populate_dispatch_vector(cumsum_1d_strided_dispatch_vector);
+
+    return;
+}
+
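The Cumsum1DContigFactory and Cumsum1DStridedFactory follow dpctl's usual dispatch-factory pattern: get() returns a kernel instantiation for supported source types and nullptr otherwise, which is why py_cumsum_1d below checks its dispatch-table entries against nullptr. A plausible sketch of the contiguous factory, where the kernel template name cumsum_1d_contig_impl is a hypothetical stand-in for the real accumulators kernel:

#include <type_traits>

// Hypothetical kernel template (stand-in for the real accumulators kernel
// instantiated for source type srcT, writing int64 partial sums).
template <typename srcT>
size_t cumsum_1d_contig_impl(sycl::queue, size_t, const char *, char *,
                             std::vector<sycl::event> const &);

template <typename fnT, typename T> struct Cumsum1DContigFactory
{
    fnT get()
    {
        if constexpr (std::is_integral_v<T>) {
            // Integer-typed inputs get a concrete kernel instantiation.
            fnT fn = cumsum_1d_contig_impl<T>;
            return fn;
        }
        else {
            // Unsupported types leave a nullptr entry, which py_cumsum_1d
            // reports via std::runtime_error.
            return nullptr;
        }
    }
};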
+size_t py_cumsum_1d(dpctl::tensor::usm_ndarray src,
+                    dpctl::tensor::usm_ndarray cumsum,
+                    sycl::queue exec_q,
+                    std::vector<sycl::event> const &depends)
+{
+    // cumsum is 1D
+    if (cumsum.get_ndim() != 1) {
+        throw py::value_error("cumsum array must be one-dimensional.");
+    }
+
+    if (!cumsum.is_c_contiguous()) {
+        throw py::value_error("Expecting `cumsum` array to be C-contiguous.");
+    }
+
+    // cumsum.shape == (src.size,)
+    auto src_size = src.get_size();
+    auto cumsum_size = cumsum.get_shape(0);
+    if (cumsum_size != src_size) {
+        throw py::value_error("Inconsistent dimensions");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, cumsum})) {
+        // FIXME: use ExecutionPlacementError
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    if (src_size == 0) {
+        return 0;
+    }
+
+    int src_typenum = src.get_typenum();
+    int cumsum_typenum = cumsum.get_typenum();
+
+    // src can be any type
+    const char *src_data = src.get_data();
+    char *cumsum_data = cumsum.get_data();
+
+    auto const &array_types = td_ns::usm_ndarray_types();
+
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
+
+    // this cumsum must be int64_t only
+    constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
+    if (cumsum_typeid != int64_typeid) {
+        throw py::value_error(
+            "Cumulative sum array must have int64 data-type.");
+    }
+
+    if (src.is_c_contiguous()) {
+        auto fn = cumsum_1d_contig_dispatch_vector[src_typeid];
+        if (fn == nullptr) {
+            throw std::runtime_error(
+                "this cumsum requires integer type, got src_typeid=" +
+                std::to_string(src_typeid));
+        }
+        return fn(exec_q, src_size, src_data, cumsum_data, depends);
+    }
+
+    const py::ssize_t *shape = src.get_shape_raw();
+    auto const &strides_vector = src.get_strides_vector();
+
+    using shT = std::vector<py::ssize_t>;
+    shT compact_shape;
+    shT compact_strides;
+
+    int src_nd = src.get_ndim();
+    int nd = src_nd;
+
+    dpctl::tensor::py_internal::compact_iteration_space(
+        nd, shape, strides_vector, compact_shape, compact_strides);
+
+    // Strided implementation
+    auto strided_fn = cumsum_1d_strided_dispatch_vector[src_typeid];
+    if (strided_fn == nullptr) {
+        throw std::runtime_error(
+            "this cumsum requires integer type, got src_typeid=" +
+            std::to_string(src_typeid));
+    }
+    std::vector<sycl::event> host_task_events;
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    const auto &ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, compact_shape, compact_strides);
+    py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
+    if (shape_strides == nullptr) {
+        sycl::event::wait(host_task_events);
+        throw std::runtime_error("Unexpected error");
+    }
+    sycl::event copy_shape_ev = std::get<2>(ptr_size_event_tuple);
+
+    if (2 * static_cast<size_t>(nd) != std::get<1>(ptr_size_event_tuple)) {
+        copy_shape_ev.wait();
+        sycl::event::wait(host_task_events);
+        sycl::free(shape_strides, exec_q);
+        throw std::runtime_error("Unexpected error");
+    }
+
+    std::vector<sycl::event> dependent_events;
+    dependent_events.reserve(depends.size() + 1);
+    dependent_events.insert(dependent_events.end(), copy_shape_ev);
+    dependent_events.insert(dependent_events.end(), depends.begin(),
+                            depends.end());
+
+    size_t total = strided_fn(exec_q, src_size, src_data, nd, shape_strides,
+                              cumsum_data, dependent_events);
+
+    sycl::event::wait(host_task_events);
+    sycl::free(shape_strides, exec_q);
+
+    return total;
+}
+
 } // namespace py_internal
 } // namespace tensor
 } // namespace dpctl
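
These entry points are presumably wired into the Python extension module elsewhere in the tree. A minimal pybind11 sketch of what that registration could look like, where the module name, function name, and docstring are assumptions:

// Hypothetical registration sketch; the actual binding code lives
// elsewhere in the extension and may differ.
#include "dpctl4pybind11.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

PYBIND11_MODULE(_example_tensor_impl, m)
{
    using namespace dpctl::tensor::py_internal;

    // Populate the type-dispatch tables once, at module import time.
    populate_mask_positions_dispatch_vectors();
    populate_cumsum_1d_dispatch_vectors();

    m.def("_cumsum_1d", &py_cumsum_1d,
          "Computes cumulative sum of `src` into the 1D int64 array "
          "`cumsum` and returns the overall total",
          py::arg("src"), py::arg("cumsum"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}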