Skip to content

Commit 2b65082

Browse files
Merge 5ec0738 into 4c66b58
2 parents 4c66b58 + 5ec0738 commit 2b65082

File tree

8 files changed

+696
-12
lines changed

8 files changed

+696
-12
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ set(_elementwise_sources
3636
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/gcd.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/heaviside.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/i0.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/interpolate.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "gcd.hpp"
3737
#include "heaviside.hpp"
3838
#include "i0.hpp"
39+
#include "interpolate.hpp"
3940
#include "lcm.hpp"
4041
#include "ldexp.hpp"
4142
#include "logaddexp2.hpp"
@@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m)
6465
init_gcd(m);
6566
init_heaviside(m);
6667
init_i0(m);
68+
init_interpolate(m);
6769
init_lcm(m);
6870
init_ldexp(m);
6971
init_logaddexp2(m);
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <complex>
27+
#include <vector>
28+
29+
#include "dpctl4pybind11.hpp"
30+
#include <pybind11/pybind11.h>
31+
#include <pybind11/stl.h>
32+
33+
// dpctl tensor headers
34+
#include "utils/output_validation.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
37+
#include "kernels/elementwise_functions/interpolate.hpp"
38+
39+
namespace py = pybind11;
40+
namespace td_ns = dpctl::tensor::type_dispatch;
41+
42+
namespace dpnp::extensions::ufunc
43+
{
44+
45+
namespace impl
46+
{
47+
48+
/**
 * @brief Maps a type to its underlying scalar type.
 *
 * The primary template is the identity mapping; the partial specialization
 * for `std::complex<U>` yields `U`.
 */
template <typename ScalarT>
struct value_type_of
{
    using type = ScalarT;
};

template <typename RealT>
struct value_type_of<std::complex<RealT>>
{
    using type = RealT;
};

/// Shorthand for `typename value_type_of<T>::type`.
template <typename T>
using value_type_of_t = typename value_type_of<T>::type;
62+
63+
/**
 * @brief Signature of a type-erased interpolation launcher.
 *
 * All array arguments are passed as untyped pointers; the concrete element
 * types are restored by the function the pointer refers to.
 */
using interpolate_fn_ptr_t =
    sycl::event (*)(sycl::queue &,                     // execution queue
                    const void *,                      // x
                    const void *,                      // idx
                    const void *,                      // xp
                    const void *,                      // fp
                    const void *,                      // left
                    const void *,                      // right
                    void *,                            // out
                    std::size_t,                       // n
                    std::size_t,                       // xp_size
                    const std::vector<sycl::event> &); // dependencies
74+
75+
template <typename T>
76+
sycl::event interpolate_call(sycl::queue &exec_q,
77+
const void *vx,
78+
const void *vidx,
79+
const void *vxp,
80+
const void *vfp,
81+
const void *vleft,
82+
const void *vright,
83+
void *vout,
84+
std::size_t n,
85+
std::size_t xp_size,
86+
const std::vector<sycl::event> &depends)
87+
{
88+
using dpctl::tensor::type_utils::is_complex_v;
89+
using TCoord = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
90+
91+
const TCoord *x = static_cast<const TCoord *>(vx);
92+
const std::int64_t *idx = static_cast<const std::int64_t *>(vidx);
93+
const TCoord *xp = static_cast<const TCoord *>(vxp);
94+
const T *fp = static_cast<const T *>(vfp);
95+
const T *left = static_cast<const T *>(vleft);
96+
const T *right = static_cast<const T *>(vright);
97+
T *out = static_cast<T *>(vout);
98+
99+
using dpnp::kernels::interpolate::interpolate_impl;
100+
sycl::event interpolate_ev = interpolate_impl<TCoord, T>(
101+
exec_q, x, idx, xp, fp, left, right, out, n, xp_size, depends);
102+
103+
return interpolate_ev;
104+
}
105+
106+
// Table of interpolation launchers with `td_ns::num_types` slots.
// `py_interpolate` indexes it with the type lookup id of `fp`; it is filled by
// `init_interpolate_dispatch_vectors()`, and `InterpolateFactory` supplies
// nullptr for types that have no implementation.
interpolate_fn_ptr_t interpolate_dispatch_vector[td_ns::num_types];
107+
108+
std::pair<sycl::event, sycl::event>
109+
py_interpolate(const dpctl::tensor::usm_ndarray &x,
110+
const dpctl::tensor::usm_ndarray &idx,
111+
const dpctl::tensor::usm_ndarray &xp,
112+
const dpctl::tensor::usm_ndarray &fp,
113+
std::optional<const dpctl::tensor::usm_ndarray> &left,
114+
std::optional<const dpctl::tensor::usm_ndarray> &right,
115+
dpctl::tensor::usm_ndarray &out,
116+
sycl::queue &exec_q,
117+
const std::vector<sycl::event> &depends)
118+
{
119+
int x_typenum = x.get_typenum();
120+
int xp_typenum = xp.get_typenum();
121+
int fp_typenum = fp.get_typenum();
122+
int out_typenum = out.get_typenum();
123+
124+
auto array_types = td_ns::usm_ndarray_types();
125+
int x_type_id = array_types.typenum_to_lookup_id(x_typenum);
126+
int xp_type_id = array_types.typenum_to_lookup_id(xp_typenum);
127+
int fp_type_id = array_types.typenum_to_lookup_id(fp_typenum);
128+
int out_type_id = array_types.typenum_to_lookup_id(out_typenum);
129+
130+
if (x_type_id != xp_type_id) {
131+
throw py::value_error("x and xp must have the same dtype");
132+
}
133+
if (fp_type_id != out_type_id) {
134+
throw py::value_error("fp and out must have the same dtype");
135+
}
136+
137+
auto fn = interpolate_dispatch_vector[fp_type_id];
138+
if (!fn) {
139+
throw py::type_error("Unsupported dtype");
140+
}
141+
142+
if (!dpctl::utils::queues_are_compatible(exec_q, {x, idx, xp, fp, out})) {
143+
throw py::value_error(
144+
"Execution queue is not compatible with allocation queues");
145+
}
146+
147+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out);
148+
149+
if (x.get_ndim() != 1 || xp.get_ndim() != 1 || fp.get_ndim() != 1 ||
150+
idx.get_ndim() != 1 || out.get_ndim() != 1)
151+
{
152+
throw py::value_error("All arrays must be one-dimensional");
153+
}
154+
155+
if (xp.get_size() != fp.get_size()) {
156+
throw py::value_error("xp and fp must have the same size");
157+
}
158+
159+
if (x.get_size() != out.get_size() || x.get_size() != idx.get_size()) {
160+
throw py::value_error("x, idx, and out must have the same size");
161+
}
162+
163+
std::size_t n = x.get_size();
164+
std::size_t xp_size = xp.get_size();
165+
166+
void *left_ptr = left.has_value() ? left.value().get_data() : nullptr;
167+
168+
void *right_ptr = right.has_value() ? right.value().get_data() : nullptr;
169+
170+
sycl::event ev =
171+
fn(exec_q, x.get_data(), idx.get_data(), xp.get_data(), fp.get_data(),
172+
left_ptr, right_ptr, out.get_data(), n, xp_size, depends);
173+
174+
sycl::event args_ev;
175+
176+
if (left.has_value() && right.has_value()) {
177+
args_ev = dpctl::utils::keep_args_alive(
178+
exec_q, {x, idx, xp, fp, out, left.value(), right.value()}, {ev});
179+
}
180+
else if (left.has_value()) {
181+
args_ev = dpctl::utils::keep_args_alive(
182+
exec_q, {x, idx, xp, fp, out, left.value()}, {ev});
183+
}
184+
else if (right.has_value()) {
185+
args_ev = dpctl::utils::keep_args_alive(
186+
exec_q, {x, idx, xp, fp, out, right.value()}, {ev});
187+
}
188+
else {
189+
args_ev =
190+
dpctl::utils::keep_args_alive(exec_q, {x, idx, xp, fp, out}, {ev});
191+
}
192+
193+
return std::make_pair(args_ev, ev);
194+
}
195+
196+
/**
197+
* @brief A factory to define pairs of supported types for which
198+
* interpolate function is available.
199+
*
200+
* @tparam T Type of input vector `a` and of result vector `y`.
201+
*/
202+
template <typename T>
203+
struct InterpolateOutputType
204+
{
205+
using value_type = typename std::disjunction<
206+
td_ns::TypeMapResultEntry<T, sycl::half>,
207+
td_ns::TypeMapResultEntry<T, float>,
208+
td_ns::TypeMapResultEntry<T, double>,
209+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
210+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
211+
td_ns::DefaultResultEntry<void>>::result_type;
212+
};
213+
214+
template <typename fnT, typename T>
215+
struct InterpolateFactory
216+
{
217+
fnT get()
218+
{
219+
if constexpr (std::is_same_v<
220+
typename InterpolateOutputType<T>::value_type, void>)
221+
{
222+
return nullptr;
223+
}
224+
else {
225+
return interpolate_call<T>;
226+
}
227+
}
228+
};
229+
230+
void init_interpolate_dispatch_vectors()
231+
{
232+
using namespace td_ns;
233+
234+
DispatchVectorBuilder<interpolate_fn_ptr_t, InterpolateFactory, num_types>
235+
dtb_interpolate;
236+
dtb_interpolate.populate_dispatch_vector(interpolate_dispatch_vector);
237+
}
238+
239+
} // namespace impl
240+
241+
/**
 * @brief Exposes the interpolation entry point to Python.
 *
 * Populates the dispatch vector and then registers `impl::py_interpolate`
 * on module `m` under the name `_interpolate`.
 */
void init_interpolate(py::module_ m)
{
    // The dispatch vector must be populated before the binding can be used.
    impl::init_interpolate_dispatch_vectors();

    m.def("_interpolate", &impl::py_interpolate, "",
          py::arg("x"),
          py::arg("idx"),
          py::arg("xp"),
          py::arg("fp"),
          py::arg("left"),
          py::arg("right"),
          py::arg("out"),
          py::arg("sycl_queue"),
          py::arg("depends") = py::list());
}
251+
252+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
/**
 * @brief Registers the `_interpolate` function on the pybind11 module `m`.
 */
void init_interpolate(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc

0 commit comments

Comments
 (0)