Skip to content

Commit ca0ff64

Browse files
Added Python API for _prod_over_axis
1 parent 0598416 commit ca0ff64

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

dpctl/tensor/libtensor/source/reduction_over_axis.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,59 @@ void populate_sum_over_axis_dispatch_tables(void)
204204

205205
} // namespace impl
206206

207+
// Product
208+
namespace impl
209+
{
210+
211+
// Dispatch tables for the product reduction. Each table is a
// td_ns::num_types x td_ns::num_types array of implementation function
// pointers with internal linkage (`static`). They are filled by
// populate_prod_over_axis_dispatch_tables().
// NOTE(review): the two indices are presumably the source and destination
// type ids -- confirm against DispatchTableBuilder.
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
212+
// Strided-implementation table populated from
// ProductOverAxisAtomicStridedFactory.
static reduction_strided_impl_fn_ptr
213+
prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types]
214+
[td_ns::num_types];
215+
// Strided-implementation table populated from
// ProductOverAxisTempsStridedFactory.
static reduction_strided_impl_fn_ptr
216+
prod_over_axis_strided_temps_dispatch_table[td_ns::num_types]
217+
[td_ns::num_types];
218+
219+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
220+
// Contiguous-implementation table populated from
// ProductOverAxis1AtomicContigFactory.
static reduction_contig_impl_fn_ptr
221+
prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types]
222+
[td_ns::num_types];
223+
// Contiguous-implementation table populated from
// ProductOverAxis0AtomicContigFactory.
static reduction_contig_impl_fn_ptr
224+
prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types]
225+
[td_ns::num_types];
226+
227+
// Populates the four prod_over_axis dispatch tables.
//
// For every table a DispatchTableBuilder is instantiated with the table's
// function-pointer type, the corresponding ProductOverAxis* factory and
// num_types, and populate_dispatch_table() is invoked on that table.
// Takes no arguments and returns nothing; the visible effect is limited to
// writing the four tables.
void populate_prod_over_axis_dispatch_tables(void)
228+
{
229+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
230+
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
231+
// Makes num_types usable unqualified below (and presumably
// DispatchTableBuilder as well -- confirm it lives in td_ns).
using namespace td_ns;
232+
233+
// ProductOverAxisAtomicStridedFactory ->
// prod_over_axis_strided_atomic_dispatch_table
using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory;
234+
DispatchTableBuilder<reduction_strided_impl_fn_ptr,
235+
ProductOverAxisAtomicStridedFactory, num_types>
236+
dtb1;
237+
dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table);
238+
239+
// ProductOverAxisTempsStridedFactory ->
// prod_over_axis_strided_temps_dispatch_table
using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory;
240+
DispatchTableBuilder<reduction_strided_impl_fn_ptr,
241+
ProductOverAxisTempsStridedFactory, num_types>
242+
dtb2;
243+
dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table);
244+
245+
// ProductOverAxis1AtomicContigFactory ->
// prod_over_axis1_contig_atomic_dispatch_table
using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory;
246+
DispatchTableBuilder<reduction_contig_impl_fn_ptr,
247+
ProductOverAxis1AtomicContigFactory, num_types>
248+
dtb3;
249+
dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table);
250+
251+
// ProductOverAxis0AtomicContigFactory ->
// prod_over_axis0_contig_atomic_dispatch_table
using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory;
252+
DispatchTableBuilder<reduction_contig_impl_fn_ptr,
253+
ProductOverAxis0AtomicContigFactory, num_types>
254+
dtb4;
255+
dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table);
256+
}
257+
258+
} // namespace impl
259+
207260
// Argmax
208261
namespace impl
209262
{
@@ -350,6 +403,45 @@ void init_reduction_functions(py::module_ m)
350403
py::arg("dst_usm_type"), py::arg("sycl_queue"));
351404
}
352405

406+
// PROD: fill the product-reduction dispatch tables, then register the
// Python-visible functions "_prod_over_axis" and
// "_prod_over_axis_dtype_supported" on module `m`.
407+
{
408+
using dpctl::tensor::py_internal::impl::
409+
populate_prod_over_axis_dispatch_tables;
410+
// Tables are populated before the callables that read them are registered.
populate_prod_over_axis_dispatch_tables();
411+
using impl::prod_over_axis0_contig_atomic_dispatch_table;
412+
using impl::prod_over_axis1_contig_atomic_dispatch_table;
413+
using impl::prod_over_axis_strided_atomic_dispatch_table;
414+
using impl::prod_over_axis_strided_temps_dispatch_table;
415+
416+
// Forwards its arguments, plus the four product dispatch tables, to
// py_reduction_over_axis and returns whatever that call returns.
// NOTE(review): the lambda body only names its own parameters and the
// file-level static tables, so [&] does not appear to capture any local
// of this scope -- confirm before relying on it.
auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
417+
const arrayT &dst, sycl::queue &exec_q,
418+
const event_vecT &depends = {}) {
419+
using dpctl::tensor::py_internal::py_reduction_over_axis;
420+
return py_reduction_over_axis(
421+
src, trailing_dims_to_reduce, dst, exec_q, depends,
422+
prod_over_axis_strided_atomic_dispatch_table,
423+
prod_over_axis_strided_temps_dispatch_table,
424+
// NOTE(review): the axis0 table is passed before the axis1 table; verify
// this matches the parameter order of py_reduction_over_axis.
prod_over_axis0_contig_atomic_dispatch_table,
425+
prod_over_axis1_contig_atomic_dispatch_table);
426+
};
427+
// Keyword names exposed to Python: src, trailing_dims_to_reduce, dst,
// sycl_queue, depends (defaults to an empty list). The docstring is empty.
m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"),
428+
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
429+
py::arg("sycl_queue"), py::arg("depends") = py::list());
430+
431+
// Forwards the dtype pair, destination USM type and queue, together with
// the two strided dispatch tables, to py_reduction_dtype_supported and
// returns its result.
auto prod_dtype_supported =
432+
[&](const py::dtype &input_dtype, const py::dtype &output_dtype,
433+
const std::string &dst_usm_type, sycl::queue &q) {
434+
using dpctl::tensor::py_internal::py_reduction_dtype_supported;
435+
return py_reduction_dtype_supported(
436+
input_dtype, output_dtype, dst_usm_type, q,
437+
prod_over_axis_strided_atomic_dispatch_table,
438+
prod_over_axis_strided_temps_dispatch_table);
439+
};
440+
// Keyword names exposed to Python: arg_dtype, out_dtype, dst_usm_type,
// sycl_queue. The docstring is empty.
m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "",
441+
py::arg("arg_dtype"), py::arg("out_dtype"),
442+
py::arg("dst_usm_type"), py::arg("sycl_queue"));
443+
}
444+
353445
// ARGMAX
354446
{
355447
using dpctl::tensor::py_internal::impl::

0 commit comments

Comments
 (0)