Skip to content

Commit 18ea734

Browse files
Merge 051dc50 into e5ab49b
2 parents e5ab49b + 051dc50 commit 18ea734

File tree

12 files changed

+893
-35
lines changed

12 files changed

+893
-35
lines changed

dpnp/backend/extensions/common/ext/common.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ struct IsNan
106106
}
107107
};
108108

109+
template <typename T>
110+
struct value_type_of
111+
{
112+
using type = T;
113+
};
114+
115+
template <typename T>
116+
struct value_type_of<std::complex<T>>
117+
{
118+
using type = T;
119+
};
120+
109121
size_t get_max_local_size(const sycl::device &device);
110122
size_t get_max_local_size(const sycl::device &device,
111123
int cpu_local_size_limit,

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ set(_elementwise_sources
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
@@ -69,6 +70,7 @@ endif()
6970
set_target_properties(${python_module_name} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON)
7071

7172
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
73+
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../common)
7274

7375
target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIR})
7476
target_include_directories(${python_module_name} PUBLIC ${Dpctl_TENSOR_INCLUDE_DIR})

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "gcd.hpp"
3737
#include "heaviside.hpp"
3838
#include "i0.hpp"
39+
#include "interpolate.hpp"
3940
#include "lcm.hpp"
4041
#include "ldexp.hpp"
4142
#include "logaddexp2.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
6465
init_gcd(m);
6566
init_heaviside(m);
6667
init_i0(m);
68+
init_interpolate(m);
6769
init_lcm(m);
6870
init_ldexp(m);
6971
init_logaddexp2(m);
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <complex>
27+
#include <vector>
28+
29+
#include "dpctl4pybind11.hpp"
30+
#include <pybind11/pybind11.h>
31+
#include <pybind11/stl.h>
32+
33+
// dpctl tensor headers
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "kernels/elementwise_functions/interpolate.hpp"
38+
39+
#include "ext/common.hpp"
40+
#include "ext/validation_utils.hpp"
41+
42+
namespace py = pybind11;
43+
namespace td_ns = dpctl::tensor::type_dispatch;
44+
namespace type_utils = dpctl::tensor::type_utils;
45+
46+
using ext::common::value_type_of;
47+
using ext::validation::array_names;
48+
using ext::validation::array_ptr;
49+
using ext::validation::common_checks;
50+
51+
namespace dpnp::extensions::ufunc
52+
{
53+
54+
namespace impl
55+
{
56+
57+
template <typename T>
58+
using value_type_of_t = typename value_type_of<T>::type;
59+
60+
typedef sycl::event (*interpolate_fn_ptr_t)(sycl::queue &,
61+
const void *, // x
62+
const void *, // idx
63+
const void *, // xp
64+
const void *, // fp
65+
const void *, // left
66+
const void *, // right
67+
void *, // out
68+
const std::size_t, // n
69+
const std::size_t, // xp_size
70+
const std::vector<sycl::event> &);
71+
72+
template <typename T, typename TIdx = std::int64_t>
73+
sycl::event interpolate_call(sycl::queue &exec_q,
74+
const void *vx,
75+
const void *vidx,
76+
const void *vxp,
77+
const void *vfp,
78+
const void *vleft,
79+
const void *vright,
80+
void *vout,
81+
const std::size_t n,
82+
const std::size_t xp_size,
83+
const std::vector<sycl::event> &depends)
84+
{
85+
using type_utils::is_complex_v;
86+
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
87+
88+
const TCoord *x = static_cast<const TCoord *>(vx);
89+
const TIdx *idx = static_cast<const TIdx *>(vidx);
90+
const TCoord *xp = static_cast<const TCoord *>(vxp);
91+
const T *fp = static_cast<const T *>(vfp);
92+
const T *left = static_cast<const T *>(vleft);
93+
const T *right = static_cast<const T *>(vright);
94+
T *out = static_cast<T *>(vout);
95+
96+
using dpnp::kernels::interpolate::interpolate_impl;
97+
sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
98+
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);
99+
100+
return interpolate_ev;
101+
}
102+
103+
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
104+
105+
void common_interpolate_checks(
106+
const dpctl::tensor::usm_ndarray &x,
107+
const dpctl::tensor::usm_ndarray &idx,
108+
const dpctl::tensor::usm_ndarray &xp,
109+
const dpctl::tensor::usm_ndarray &fp,
110+
const dpctl::tensor::usm_ndarray &out,
111+
const std::optional<const dpctl::tensor::usm_ndarray> &left,
112+
const std::optional<const dpctl::tensor::usm_ndarray> &right)
113+
{
114+
array_names names = {{&x, "x"}, {&xp, "xp"}, {&fp, "fp"}, {&out, "out"}};
115+
116+
auto array_types = td_ns::usm_ndarray_types();
117+
int x_type_id = array_types.typenum_to_lookup_id(x.get_typenum());
118+
int idx_type_id = array_types.typenum_to_lookup_id(idx.get_typenum());
119+
int xp_type_id = array_types.typenum_to_lookup_id(xp.get_typenum());
120+
int fp_type_id = array_types.typenum_to_lookup_id(fp.get_typenum());
121+
int out_type_id = array_types.typenum_to_lookup_id(out.get_typenum());
122+
123+
if (x_type_id != xp_type_id) {
124+
throw py::value_error("x and xp must have the same dtype");
125+
}
126+
if (fp_type_id != out_type_id) {
127+
throw py::value_error("fp and out must have the same dtype");
128+
}
129+
if (idx_type_id != static_cast<int>(td_ns::typenum_t::INT64)) {
130+
throw py::value_error("The type of idx must be int64");
131+
}
132+
133+
auto left_v = left ? &left.value() : nullptr;
134+
if (left_v) {
135+
names.insert({left_v, "left"});
136+
if (left_v->get_ndim() != 0) {
137+
throw py::value_error("left must be a zero-dimensional array");
138+
}
139+
140+
int left_type_id =
141+
array_types.typenum_to_lookup_id(left_v->get_typenum());
142+
if (left_type_id != fp_type_id) {
143+
throw py::value_error(
144+
"left must have the same dtype as fp and out");
145+
}
146+
}
147+
148+
auto right_v = right ? &right.value() : nullptr;
149+
if (right_v) {
150+
names.insert({right_v, "right"});
151+
if (right_v->get_ndim() != 0) {
152+
throw py::value_error("right must be a zero-dimensional array");
153+
}
154+
155+
int right_type_id =
156+
array_types.typenum_to_lookup_id(right_v->get_typenum());
157+
if (right_type_id != fp_type_id) {
158+
throw py::value_error(
159+
"right must have the same dtype as fp and out");
160+
}
161+
}
162+
163+
common_checks({&x, &xp, &fp, left_v, right_v}, {&out}, names);
164+
165+
if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 ||
166+
idx.get_ndim() != 1 || out.get_ndim() != 1)
167+
{
168+
throw py::value_error("All arrays must be one-dimensional");
169+
}
170+
171+
if (xp.get_size() != fp.get_size()) {
172+
throw py::value_error("xp and fp must have the same size");
173+
}
174+
175+
if (xp.get_size() == 0) {
176+
throw py::value_error("array of sample points is empty");
177+
}
178+
179+
if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) {
180+
throw py::value_error("x, idx, and out must have the same size");
181+
}
182+
}
183+
184+
std::pair<sycl::event, sycl::event>
185+
py_interpolate(const dpctl::tensor::usm_ndarray &x,
186+
const dpctl::tensor::usm_ndarray &idx,
187+
const dpctl::tensor::usm_ndarray &xp,
188+
const dpctl::tensor::usm_ndarray &fp,
189+
std::optional<const dpctl::tensor::usm_ndarray> &left,
190+
std::optional<const dpctl::tensor::usm_ndarray> &right,
191+
dpctl::tensor::usm_ndarray &out,
192+
sycl::queue &exec_q,
193+
const std::vector<sycl::event> &depends)
194+
{
195+
common_interpolate_checks(x, idx, xp, fp, out, left, right);
196+
197+
if (x.get_size() == 0) {
198+
return {sycl::event(), sycl::event()};
199+
}
200+
201+
int out_typenum = out.get_typenum();
202+
203+
auto array_types = td_ns::usm_ndarray_types();
204+
int out_type_id = array_types.typenum_to_lookup_id(out_typenum);
205+
206+
auto fn = interpolate_dispatch_vector[out_type_id];
207+
if (!fn) {
208+
throw py::type_error("Unsupported dtype");
209+
}
210+
211+
std::size_t n = x.get_size();
212+
std::size_t xp_size = xp.get_size();
213+
214+
void *left_ptr = left ? left.value().get_data() : nullptr;
215+
void *right_ptr = right ? right.value().get_data() : nullptr;
216+
217+
sycl::event ev =
218+
fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(),
219+
left_ptr, right_ptr, out.get_data(), n, xp_size, depends);
220+
221+
sycl::event args_ev;
222+
223+
if (left && right) {
224+
args_ev = dpctl::utils::keep_args_alive(
225+
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev});
226+
}
227+
else if (left || right) {
228+
args_ev = dpctl::utils::keep_args_alive(
229+
exec_q, {x, idx, xp, fp, out, left ? left.value() : right.value()},
230+
{ev});
231+
}
232+
else {
233+
args_ev =
234+
dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev});
235+
}
236+
237+
return std::make_pair(args_ev, ev);
238+
}
239+
240+
/**
241+
* @brief A factory to define pairs of supported types for which
242+
* interpolate function is available.
243+
*
244+
* @tparam T Type of input vector `a` and of result vector `y`.
245+
*/
246+
template <typename T>
247+
struct InterpolateOutputType
248+
{
249+
using value_type = typename std::disjunction<
250+
td_ns::TypeMapResultEntry<T, sycl::half>,
251+
td_ns::TypeMapResultEntry<T, float>,
252+
td_ns::TypeMapResultEntry<T, double>,
253+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
254+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
255+
td_ns::DefaultResultEntry<void>>::result_type;
256+
};
257+
258+
template <typename fnT, typename T>
259+
struct InterpolateFactory
260+
{
261+
fnT get()
262+
{
263+
if constexpr (std::is_same_v<
264+
typename InterpolateOutputType<T>::value_type, void>)
265+
{
266+
return nullptr;
267+
}
268+
else {
269+
return interpolate_call<T>;
270+
}
271+
}
272+
};
273+
274+
void init_interpolate_dispatch_vectors()
275+
{
276+
using namespace td_ns;
277+
278+
DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types>
279+
dtb_interpolate;
280+
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector);
281+
}
282+
283+
} // namespace impl
284+
285+
void init_interpolate(py::module_ m)
286+
{
287+
impl::init_interpolate_dispatch_vectors();
288+
289+
using impl::py_interpolate;
290+
m.def("_interpolate", &py_interpolate, "", py::arg("x"), py::arg("idx"),
291+
py::arg("xp"), py::arg("fp"), py::arg("left"), py::arg("right"),
292+
py::arg("out"), py::arg("sycl_queue"),
293+
py::arg("depends") = py::list());
294+
}
295+
296+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_interpolate(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)