@@ -204,6 +204,59 @@ void populate_sum_over_axis_dispatch_tables(void)
204
204
205
205
} // namespace impl
206
206
207
+ // Product
208
+ namespace impl
209
+ {
210
+
211
+ using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
212
+ static reduction_strided_impl_fn_ptr
213
+ prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types]
214
+ [td_ns::num_types];
215
+ static reduction_strided_impl_fn_ptr
216
+ prod_over_axis_strided_temps_dispatch_table[td_ns::num_types]
217
+ [td_ns::num_types];
218
+
219
+ using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
220
+ static reduction_contig_impl_fn_ptr
221
+ prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types]
222
+ [td_ns::num_types];
223
+ static reduction_contig_impl_fn_ptr
224
+ prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types]
225
+ [td_ns::num_types];
226
+
227
+ void populate_prod_over_axis_dispatch_tables (void )
228
+ {
229
+ using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
230
+ using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
231
+ using namespace td_ns ;
232
+
233
+ using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory;
234
+ DispatchTableBuilder<reduction_strided_impl_fn_ptr,
235
+ ProductOverAxisAtomicStridedFactory, num_types>
236
+ dtb1;
237
+ dtb1.populate_dispatch_table (prod_over_axis_strided_atomic_dispatch_table);
238
+
239
+ using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory;
240
+ DispatchTableBuilder<reduction_strided_impl_fn_ptr,
241
+ ProductOverAxisTempsStridedFactory, num_types>
242
+ dtb2;
243
+ dtb2.populate_dispatch_table (prod_over_axis_strided_temps_dispatch_table);
244
+
245
+ using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory;
246
+ DispatchTableBuilder<reduction_contig_impl_fn_ptr,
247
+ ProductOverAxis1AtomicContigFactory, num_types>
248
+ dtb3;
249
+ dtb3.populate_dispatch_table (prod_over_axis1_contig_atomic_dispatch_table);
250
+
251
+ using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory;
252
+ DispatchTableBuilder<reduction_contig_impl_fn_ptr,
253
+ ProductOverAxis0AtomicContigFactory, num_types>
254
+ dtb4;
255
+ dtb4.populate_dispatch_table (prod_over_axis0_contig_atomic_dispatch_table);
256
+ }
257
+
258
+ } // namespace impl
259
+
207
260
// Argmax
208
261
namespace impl
209
262
{
@@ -350,6 +403,45 @@ void init_reduction_functions(py::module_ m)
350
403
py::arg (" dst_usm_type" ), py::arg (" sycl_queue" ));
351
404
}
352
405
406
+ // PROD
407
+ {
408
+ using dpctl::tensor::py_internal::impl::
409
+ populate_prod_over_axis_dispatch_tables;
410
+ populate_prod_over_axis_dispatch_tables ();
411
+ using impl::prod_over_axis0_contig_atomic_dispatch_table;
412
+ using impl::prod_over_axis1_contig_atomic_dispatch_table;
413
+ using impl::prod_over_axis_strided_atomic_dispatch_table;
414
+ using impl::prod_over_axis_strided_temps_dispatch_table;
415
+
416
+ auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
417
+ const arrayT &dst, sycl::queue &exec_q,
418
+ const event_vecT &depends = {}) {
419
+ using dpctl::tensor::py_internal::py_reduction_over_axis;
420
+ return py_reduction_over_axis (
421
+ src, trailing_dims_to_reduce, dst, exec_q, depends,
422
+ prod_over_axis_strided_atomic_dispatch_table,
423
+ prod_over_axis_strided_temps_dispatch_table,
424
+ prod_over_axis0_contig_atomic_dispatch_table,
425
+ prod_over_axis1_contig_atomic_dispatch_table);
426
+ };
427
+ m.def (" _prod_over_axis" , prod_pyapi, " " , py::arg (" src" ),
428
+ py::arg (" trailing_dims_to_reduce" ), py::arg (" dst" ),
429
+ py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
430
+
431
+ auto prod_dtype_supported =
432
+ [&](const py::dtype &input_dtype, const py::dtype &output_dtype,
433
+ const std::string &dst_usm_type, sycl::queue &q) {
434
+ using dpctl::tensor::py_internal::py_reduction_dtype_supported;
435
+ return py_reduction_dtype_supported (
436
+ input_dtype, output_dtype, dst_usm_type, q,
437
+ prod_over_axis_strided_atomic_dispatch_table,
438
+ prod_over_axis_strided_temps_dispatch_table);
439
+ };
440
+ m.def (" _prod_over_axis_dtype_supported" , prod_dtype_supported, " " ,
441
+ py::arg (" arg_dtype" ), py::arg (" out_dtype" ),
442
+ py::arg (" dst_usm_type" ), py::arg (" sycl_queue" ));
443
+ }
444
+
353
445
// ARGMAX
354
446
{
355
447
using dpctl::tensor::py_internal::impl::
0 commit comments