Skip to content

Commit 4ddd18c

Browse files
authored
Merge 6fe1e2b into 76f4360
2 parents 76f4360 + 6fe1e2b commit 4ddd18c

File tree

10 files changed

+441
-14
lines changed

10 files changed

+441
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
* Added implementation of `dpnp.hanning` [#2358](https://github.com/IntelPython/dpnp/pull/2358)
1313
* Added implementation of `dpnp.blackman` [#2363](https://github.com/IntelPython/dpnp/pull/2363)
1414
* Added implementation of `dpnp.bartlett` [#2366](https://github.com/IntelPython/dpnp/pull/2366)
15+
* Added implementation of `dpnp.kaiser` [#2387](https://github.com/IntelPython/dpnp/pull/2387)
1516

1617
### Changed
1718

dpnp/backend/extensions/window/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
set(python_module_name _window_impl)
2828
set(_module_src
29+
${CMAKE_CURRENT_SOURCE_DIR}/kaiser.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
3031
)
3132

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include "kaiser.hpp"
27+
#include "utils/output_validation.hpp"
28+
#include "utils/type_dispatch.hpp"
29+
#include "utils/type_utils.hpp"
30+
#include <sycl/sycl.hpp>
31+
32+
/**
33+
* Version of SYCL DPC++ 2025.1 compiler where an issue with
34+
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
35+
*/
36+
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
37+
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
38+
#endif
39+
40+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
41+
#include <sycl/ext/intel/math.hpp>
42+
#endif
43+
44+
#include "../kernels/elementwise_functions/i0.hpp"
45+
46+
namespace dpnp::extensions::window
47+
{
48+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
49+
50+
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
51+
char *,
52+
const std::size_t,
53+
const py::object &,
54+
const std::vector<sycl::event> &);
55+
56+
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
57+
58+
template <typename T>
59+
class KaiserFunctor
60+
{
61+
private:
62+
T *data = nullptr;
63+
const std::size_t N;
64+
const T beta;
65+
66+
public:
67+
KaiserFunctor(T *data, const std::size_t N, const T beta)
68+
: data(data), N(N), beta(beta)
69+
{
70+
}
71+
72+
void operator()(sycl::id<1> id) const
73+
{
74+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
75+
using sycl::ext::intel::math::cyl_bessel_i0;
76+
#else
77+
using dpnp::kernels::i0::impl::cyl_bessel_i0;
78+
#endif
79+
80+
const auto i = id.get(0);
81+
const T alpha = (N - 1) / T(2);
82+
const T tmp = (i - alpha) / alpha;
83+
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
84+
cyl_bessel_i0(beta);
85+
}
86+
};
87+
88+
template <typename T, template <typename> class Functor>
89+
sycl::event kaiser_impl(sycl::queue &q,
90+
char *result,
91+
const std::size_t nelems,
92+
const py::object &py_beta,
93+
const std::vector<sycl::event> &depends)
94+
{
95+
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
96+
97+
T *res = reinterpret_cast<T *>(result);
98+
const T beta = py::cast<const T>(py_beta);
99+
100+
sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
101+
cgh.depends_on(depends);
102+
103+
using KaiserKernel = Functor<T>;
104+
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
105+
KaiserKernel(res, nelems, beta));
106+
});
107+
108+
return kaiser_ev;
109+
}
110+
111+
template <typename fnT, typename T>
112+
struct KaiserFactory
113+
{
114+
fnT get()
115+
{
116+
if constexpr (std::is_floating_point_v<T>) {
117+
return kaiser_impl<T, KaiserFunctor>;
118+
}
119+
else {
120+
return nullptr;
121+
}
122+
}
123+
};
124+
125+
std::pair<sycl::event, sycl::event>
126+
py_kaiser(sycl::queue &exec_q,
127+
const py::object &py_beta,
128+
const dpctl::tensor::usm_ndarray &result,
129+
const std::vector<sycl::event> &depends)
130+
{
131+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
132+
133+
int nd = result.get_ndim();
134+
if (nd != 1) {
135+
throw py::value_error("Array should be 1d");
136+
}
137+
138+
if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
139+
throw py::value_error(
140+
"Execution queue is not compatible with allocation queue.");
141+
}
142+
143+
const bool is_result_c_contig = result.is_c_contiguous();
144+
if (!is_result_c_contig) {
145+
throw py::value_error("The result input array is not c-contiguous.");
146+
}
147+
148+
size_t nelems = result.get_size();
149+
if (nelems == 0) {
150+
return std::make_pair(sycl::event{}, sycl::event{});
151+
}
152+
153+
int result_typenum = result.get_typenum();
154+
auto array_types = dpctl_td_ns::usm_ndarray_types();
155+
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
156+
auto fn = kaiser_dispatch_vector[result_type_id];
157+
158+
if (fn == nullptr) {
159+
throw std::runtime_error("Type of given array is not supported");
160+
}
161+
162+
char *result_typeless_ptr = result.get_data();
163+
sycl::event kaiser_ev =
164+
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
165+
sycl::event args_ev =
166+
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});
167+
168+
return std::make_pair(args_ev, kaiser_ev);
169+
}
170+
171+
void init_kaiser_dispatch_vectors()
172+
{
173+
dpctl_td_ns::DispatchVectorBuilder<kaiser_fn_ptr_t, KaiserFactory,
174+
dpctl_td_ns::num_types>
175+
contig;
176+
contig.populate_dispatch_vector(kaiser_dispatch_vector);
177+
178+
return;
179+
}
180+
181+
} // namespace dpnp::extensions::window
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <dpctl4pybind11.hpp>
29+
#include <sycl/sycl.hpp>
30+
31+
namespace dpnp::extensions::window
32+
{
33+
extern std::pair<sycl::event, sycl::event>
34+
py_kaiser(sycl::queue &exec_q,
35+
const py::object &beta,
36+
const dpctl::tensor::usm_ndarray &result,
37+
const std::vector<sycl::event> &depends);
38+
39+
extern void init_kaiser_dispatch_vectors(void);
40+
41+
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/window_py.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "common.hpp"
3636
#include "hamming.hpp"
3737
#include "hanning.hpp"
38+
#include "kaiser.hpp"
3839

3940
namespace window_ns = dpnp::extensions::window;
4041
namespace py = pybind11;
@@ -111,4 +112,12 @@ PYBIND11_MODULE(_window_impl, m)
111112
py::arg("sycl_queue"), py::arg("result"),
112113
py::arg("depends") = py::list());
113114
}
115+
116+
{
117+
window_ns::init_kaiser_dispatch_vectors();
118+
119+
m.def("_kaiser", window_ns::py_kaiser, "Call Kaiser kernel",
120+
py::arg("sycl_queue"), py::arg("beta"), py::arg("result"),
121+
py::arg("depends") = py::list());
122+
}
114123
}

0 commit comments

Comments
 (0)