Skip to content

Commit a3d04ba

Browse files
Add pow call from OneMKL by pybind11 extension
1 parent 96b9759 commit a3d04ba

File tree

4 files changed

+156
-2
lines changed

4 files changed

+156
-2
lines changed

dpnp/backend/extensions/vm/pow.hpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event pow_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::pow(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct PowContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::PowOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return pow_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,31 @@ struct LnOutputType
106106
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
107107
};
108108

109+
/**
110+
* @brief A factory to define pairs of supported types for which
111+
* MKL VM library provides support in oneapi::mkl::vm::pow<T> function.
112+
*
113+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
114+
*/
115+
template <typename T>
116+
struct PowOutputType
117+
{
118+
using value_type = typename std::disjunction<
119+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
120+
std::complex<double>,
121+
T,
122+
std::complex<double>,
123+
std::complex<double>>,
124+
dpctl_td_ns::BinaryTypeMapResultEntry<T,
125+
std::complex<float>,
126+
T,
127+
std::complex<float>,
128+
std::complex<float>>,
129+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
130+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
131+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
132+
};
133+
109134
/**
110135
* @brief A factory to define pairs of supported types for which
111136
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "cos.hpp"
3535
#include "div.hpp"
3636
#include "ln.hpp"
37+
#include "pow.hpp"
3738
#include "sin.hpp"
3839
#include "sqr.hpp"
3940
#include "sqrt.hpp"
@@ -46,6 +47,7 @@ using vm_ext::binary_impl_fn_ptr_t;
4647
using vm_ext::unary_impl_fn_ptr_t;
4748

4849
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
50+
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
4951

5052
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5153
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
@@ -144,6 +146,36 @@ PYBIND11_MODULE(_vm_impl, m)
144146
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
145147
}
146148

149+
// BinaryUfunc: ==== Pow(x1, x2) ====
150+
{
151+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
152+
vm_ext::PowContigFactory>(
153+
pow_dispatch_vector);
154+
155+
auto pow_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
156+
arrayT dst, const event_vecT &depends = {}) {
157+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
158+
pow_dispatch_vector);
159+
};
160+
m.def("_pow", pow_pyapi,
161+
"Call `pow` function from OneMKL VM library to performs element "
162+
"by element exponentiation of vector `src1` raised to the power "
163+
"of vector `src2` to resulting vector `dst`",
164+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
165+
py::arg("dst"), py::arg("depends") = py::list());
166+
167+
auto pow_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
168+
arrayT src2, arrayT dst) {
169+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
170+
pow_dispatch_vector);
171+
};
172+
m.def("_mkl_pow_to_call", pow_need_to_call_pyapi,
173+
"Check input arguments to answer if `pow` function from "
174+
"OneMKL VM library can be used",
175+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
176+
py::arg("dst"));
177+
}
178+
147179
// UnaryUfunc: ==== Sin(x) ====
148180
{
149181
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,15 +869,31 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
869869

870870

871871
def dpnp_power(x1, x2, out=None, order="K"):
    """
    Compute power via the OneMKL VM pybind11 extension when it is applicable.

    If the extension reports that it cannot be used for the given arguments,
    the dpctl.tensor implementation of pow() is called instead.

    """

    def _call_pow(src1, src2, dst, sycl_queue, depends=None):
        """Callback handed to the BinaryElementwiseFunc class of dpctl.tensor."""

        deps = [] if depends is None else depends

        if not vmi._mkl_pow_to_call(sycl_queue, src1, src2, dst):
            # extension is not applicable: use the dpctl.tensor implementation
            return ti._pow(src1, src2, dst, sycl_queue, deps)

        # pow() from OneMKL VM through the pybind11 extension
        return vmi._pow(sycl_queue, src1, src2, dst, deps)

    # dpctl.tensor accepts only usm_ndarray or scalar operands
    lhs = dpnp.get_usm_ndarray_or_scalar(x1)
    rhs = dpnp.get_usm_ndarray_or_scalar(x2)
    if out is None:
        out_usm = None
    else:
        out_usm = dpnp.get_usm_ndarray(out)

    power_func = BinaryElementwiseFunc(
        "pow", ti._pow_result_type, _call_pow, _power_docstring_
    )
    result_usm = power_func(lhs, rhs, out=out_usm, order=order)
    return dpnp_array._create_from_usm_ndarray(result_usm)

0 commit comments

Comments
 (0)