Skip to content

Commit 96481f9

Browse files
Merge 82c657e into 316240c
2 parents 316240c + 82c657e commit 96481f9

File tree

10 files changed

+858
-23
lines changed

10 files changed

+858
-23
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ set(_elementwise_sources
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
@@ -69,6 +70,7 @@ endif()
6970
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
7071

7172
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
73+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
7274

7375
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
7476
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "gcd.hpp"
3737
#include "heaviside.hpp"
3838
#include "i0.hpp"
39+
#include "interpolate.hpp"
3940
#include "lcm.hpp"
4041
#include "ldexp.hpp"
4142
#include "logaddexp2.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
6465
init_gcd(m);
6566
init_heaviside(m);
6667
init_i0(m);
68+
init_interpolate(m);
6769
init_lcm(m);
6870
init_ldexp(m);
6971
init_logaddexp2(m);
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <complex>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

#include "dpctl4pybind11.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// dpctl tensor headers
#include "utils/output_validation.hpp"
#include "utils/type_dispatch.hpp"

#include "kernels/elementwise_functions/interpolate.hpp"

#include "ext/validation_utils.hpp"
40+
41+
namespace py = pybind11;
42+
namespace td_ns = dpctl::tensor::type_dispatch;
43+
44+
using ext::validation::array_names;
45+
using ext::validation::array_ptr;
46+
using ext::validation::common_checks;
47+
48+
namespace dpnp::extensions::ufunc
49+
{
50+
51+
namespace impl
52+
{
53+
54+
/**
 * @brief Maps a type to its underlying scalar type.
 *
 * The primary template is the identity: `value_type_of<T>::type` is `T`.
 *
 * @tparam T Any type.
 */
template <typename T>
struct value_type_of
{
    using type = T;
};

/**
 * @brief Specialization for `std::complex<T>`: yields the component type `T`.
 *
 * @tparam T Component type of the complex number.
 */
template <typename T>
struct value_type_of<std::complex<T>>
{
    using type = T;
};

/// Convenience alias for `typename value_type_of<T>::type`.
template <typename T>
using value_type_of_t = typename value_type_of<T>::type;
68+
69+
/**
 * @brief Signature shared by every type-erased interpolate launcher stored in
 * the dispatch vector.
 *
 * Array arguments travel as untyped pointers; the concrete element types are
 * restored inside the launcher instantiated for a given type.
 */
using interpolate_fn_ptr_t =
    sycl::event (*)(sycl::queue &,                     // execution queue
                    const void *,                      // x
                    const void *,                      // idx
                    const void *,                      // xp
                    const void *,                      // fp
                    const void *,                      // left
                    const void *,                      // right
                    void *,                            // out
                    std::size_t,                       // n
                    std::size_t,                       // xp_size
                    const std::vector<sycl::event> &); // depends
80+
81+
template <typename T>
82+
sycl::event interpolate_call(sycl::queue &exec_q,
83+
const void *vx,
84+
const void *vidx,
85+
const void *vxp,
86+
const void *vfp,
87+
const void *vleft,
88+
const void *vright,
89+
void *vout,
90+
std::size_t n,
91+
std::size_t xp_size,
92+
const std::vector<sycl::event> &depends)
93+
{
94+
using dpctl::tensor::type_utils::is_complex_v;
95+
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
96+
97+
const TCoord *x = static_cast<const TCoord *>(vx);
98+
const std::int64_t *idx = static_cast<const std::int64_t *>(vidx);
99+
const TCoord *xp = static_cast<const TCoord *>(vxp);
100+
const T *fp = static_cast<const T *>(vfp);
101+
const T *left = static_cast<const T *>(vleft);
102+
const T *right = static_cast<const T *>(vright);
103+
T *out = static_cast<T *>(vout);
104+
105+
using dpnp::kernels::interpolate::interpolate_impl;
106+
sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
107+
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);
108+
109+
return interpolate_ev;
110+
}
111+
112+
// Table of type-erased interpolate launchers. py_interpolate indexes it with
// the dpctl type lookup id of `out`; init_interpolate_dispatch_vectors()
// populates it, and InterpolateFactory yields a null entry for every type it
// does not support.
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
113+
114+
void common_interpolate_checks(
115+
const dpctl::tensor::usm_ndarray &x,
116+
const dpctl::tensor::usm_ndarray &idx,
117+
const dpctl::tensor::usm_ndarray &xp,
118+
const dpctl::tensor::usm_ndarray &fp,
119+
const dpctl::tensor::usm_ndarray &out,
120+
const std::optional<const dpctl::tensor::usm_ndarray> &left,
121+
const std::optional<const dpctl::tensor::usm_ndarray> &right)
122+
{
123+
array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}};
124+
125+
auto array_types = td_ns::usm_ndarray_types();
126+
int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum());
127+
int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum());
128+
int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum());
129+
int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum());
130+
131+
if (x_type_id != xp_type_id) {
132+
throw py::value_error("x and xp must have the same dtype");
133+
}
134+
if (fp_type_id != out_type_id) {
135+
throw py::value_error("fp and out must have the same dtype");
136+
}
137+
138+
if (left) {
139+
const auto &l = left.value();
140+
names.insert({&l, "left"});
141+
if (l.get_ndim() != 0) {
142+
throw py::value_error("left must be a zero-dimensional array");
143+
}
144+
145+
int left_type_id = array_types.typenum_to_lookup_id(l.get_typenum());
146+
if (left_type_id != fp_type_id) {
147+
throw py::value_error(
148+
"left must have the same dtype as fp and out");
149+
}
150+
}
151+
152+
if (right) {
153+
const auto &r = right.value();
154+
names.insert({&r, "right"});
155+
if (r.get_ndim() != 0) {
156+
throw py::value_error("right must be a zero-dimensional array");
157+
}
158+
159+
int right_type_id = array_types.typenum_to_lookup_id(r.get_typenum());
160+
if (right_type_id != fp_type_id) {
161+
throw py::value_error(
162+
"right must have the same dtype as fp and out");
163+
}
164+
}
165+
166+
common_checks({&x, &xp, &fp, left ? &left.value() : nullptr,
167+
right ? &right.value() : nullptr},
168+
{&out}, names);
169+
170+
if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 ||
171+
idx.get_ndim() != 1 || out.get_ndim() != 1)
172+
{
173+
throw py::value_error("All arrays must be one-dimensional");
174+
}
175+
176+
if (xp.get_size() != fp.get_size()) {
177+
throw py::value_error("xp and fp must have the same size");
178+
}
179+
180+
if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) {
181+
throw py::value_error("x, idx, and out must have the same size");
182+
}
183+
}
184+
185+
std::pair<sycl::event, sycl::event>
186+
py_interpolate(const dpctl::tensor::usm_ndarray &x,
187+
const dpctl::tensor::usm_ndarray &idx,
188+
const dpctl::tensor::usm_ndarray &xp,
189+
const dpctl::tensor::usm_ndarray &fp,
190+
std::optional<const dpctl::tensor::usm_ndarray> &left,
191+
std::optional<const dpctl::tensor::usm_ndarray> &right,
192+
dpctl::tensor::usm_ndarray &out,
193+
sycl::queue &exec_q,
194+
const std::vector<sycl::event> &depends)
195+
{
196+
if (x.get_size() == 0) {
197+
return {sycl::event(), sycl::event()};
198+
}
199+
200+
common_interpolate_checks(x, idx, xp, fp, out, left, right);
201+
202+
int out_typenum = out.get_typenum();
203+
204+
auto array_types = td_ns::usm_ndarray_types();
205+
int out_type_id = array_types.typenum_to_lookup_id(out_typenum);
206+
207+
auto fn = interpolate_dispatch_vector[out_type_id];
208+
if (!fn) {
209+
throw py::type_error("Unsupported dtype");
210+
}
211+
212+
std::size_t n = x.get_size();
213+
std::size_t xp_size = xp.get_size();
214+
215+
void *left_ptr = left ? left.value().get_data() : nullptr;
216+
void *right_ptr = right ? right.value().get_data() : nullptr;
217+
218+
sycl::event ev =
219+
fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(),
220+
left_ptr, right_ptr, out.get_data(), n, xp_size, depends);
221+
222+
sycl::event args_ev;
223+
224+
if (left && right) {
225+
args_ev = dpctl::utils::keep_args_alive(
226+
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev});
227+
}
228+
else if (left) {
229+
args_ev = dpctl::utils::keep_args_alive(
230+
exec_q, {x, idx, xp, fp, out, left.value()}, {ev});
231+
}
232+
else if (right) {
233+
args_ev = dpctl::utils::keep_args_alive(
234+
exec_q, {x, idx, xp, fp, out, right.value()}, {ev});
235+
}
236+
else {
237+
args_ev =
238+
dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev});
239+
}
240+
241+
return std::make_pair(args_ev, ev);
242+
}
243+
244+
/**
 * @brief A factory to define the set of supported types for which
 * the interpolate function is available.
 *
 * @tparam T Type of the sample values `fp` and of the result array `out`.
 */
250+
template <typename T>
251+
struct InterpolateOutputType
252+
{
253+
using value_type = typename std::disjunction<
254+
td_ns::TypeMapResultEntry<T, sycl::half>,
255+
td_ns::TypeMapResultEntry<T, float>,
256+
td_ns::TypeMapResultEntry<T, double>,
257+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
258+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
259+
td_ns::DefaultResultEntry<void>>::result_type;
260+
};
261+
262+
template <typename fnT, typename T>
263+
struct InterpolateFactory
264+
{
265+
fnT get()
266+
{
267+
if constexpr (std::is_same_v<
268+
typename InterpolateOutputType<T>::value_type, void>)
269+
{
270+
return nullptr;
271+
}
272+
else {
273+
return interpolate_call<T>;
274+
}
275+
}
276+
};
277+
278+
void init_interpolate_dispatch_vectors()
279+
{
280+
using namespace td_ns;
281+
282+
DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types>
283+
dtb_interpolate;
284+
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector);
285+
}
286+
287+
} // namespace impl
288+
289+
/**
 * @brief Registers the `_interpolate` function in the given Python module.
 *
 * The dispatch vector is populated first, before the binding is defined.
 *
 * @param m Module that receives the binding.
 */
void init_interpolate(py::module_ m)
{
    impl::init_interpolate_dispatch_vectors();

    m.def("_interpolate", &impl::py_interpolate, "",
          // inputs
          py::arg("x"), py::arg("idx"), py::arg("xp"), py::arg("fp"),
          py::arg("left"), py::arg("right"),
          // output and execution control
          py::arg("out"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}
299+
300+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_interpolate(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)