Skip to content

Commit abb296d

Browse files
authored
Merge branch 'master' into remove-workaraound-dpctl-2030
2 parents fbca4cc + 6ff4e02 commit abb296d

File tree

12 files changed

+487
-27
lines changed

12 files changed

+487
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
* Added implementation of `dpnp.hanning` [#2358](https://github.com/IntelPython/dpnp/pull/2358)
1313
* Added implementation of `dpnp.blackman` [#2363](https://github.com/IntelPython/dpnp/pull/2363)
1414
* Added implementation of `dpnp.bartlett` [#2366](https://github.com/IntelPython/dpnp/pull/2366)
15+
* Added implementation of `dpnp.kaiser` [#2387](https://github.com/IntelPython/dpnp/pull/2387)
1516

1617
### Changed
1718

dpnp/backend/extensions/window/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
set(python_module_name _window_impl)
2828
set(_module_src
29+
${CMAKE_CURRENT_SOURCE_DIR}/kaiser.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/window_py.cpp
3031
)
3132

dpnp/backend/extensions/window/common.hpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
6767
return window_ev;
6868
}
6969

70-
std::pair<sycl::event, sycl::event>
71-
py_window(sycl::queue &exec_q,
70+
template <typename funcPtrT>
71+
std::tuple<size_t, char *, funcPtrT>
72+
window_fn(sycl::queue &exec_q,
7273
const dpctl::tensor::usm_ndarray &result,
73-
const std::vector<sycl::event> &depends,
74-
const window_fn_ptr_t *window_dispatch_vector)
74+
const funcPtrT *window_dispatch_vector)
7575
{
7676
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
7777

@@ -92,30 +92,48 @@ std::pair<sycl::event, sycl::event>
9292

9393
size_t nelems = result.get_size();
9494
if (nelems == 0) {
95-
return std::make_pair(sycl::event{}, sycl::event{});
95+
return std::make_tuple(nelems, nullptr, nullptr);
9696
}
9797

9898
int result_typenum = result.get_typenum();
9999
auto array_types = dpctl_td_ns::usm_ndarray_types();
100100
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
101-
auto fn = window_dispatch_vector[result_type_id];
101+
funcPtrT fn = window_dispatch_vector[result_type_id];
102102

103103
if (fn == nullptr) {
104104
throw std::runtime_error("Type of given array is not supported");
105105
}
106106

107107
char *result_typeless_ptr = result.get_data();
108+
return std::make_tuple(nelems, result_typeless_ptr, fn);
109+
}
110+
111+
inline std::pair<sycl::event, sycl::event>
112+
py_window(sycl::queue &exec_q,
113+
const dpctl::tensor::usm_ndarray &result,
114+
const std::vector<sycl::event> &depends,
115+
const window_fn_ptr_t *window_dispatch_vector)
116+
{
117+
auto [nelems, result_typeless_ptr, fn] =
118+
window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);
119+
120+
if (nelems == 0) {
121+
return std::make_pair(sycl::event{}, sycl::event{});
122+
}
123+
108124
sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
109125
sycl::event args_ev =
110126
dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});
111127

112128
return std::make_pair(args_ev, window_ev);
113129
}
114130

115-
template <template <typename fnT, typename T> typename factoryT>
116-
void init_window_dispatch_vectors(window_fn_ptr_t window_dispatch_vector[])
131+
template <typename funcPtrT,
132+
template <typename fnT, typename T>
133+
typename factoryT>
134+
void init_window_dispatch_vectors(funcPtrT window_dispatch_vector[])
117135
{
118-
dpctl_td_ns::DispatchVectorBuilder<window_fn_ptr_t, factoryT,
136+
dpctl_td_ns::DispatchVectorBuilder<funcPtrT, factoryT,
119137
dpctl_td_ns::num_types>
120138
contig;
121139
contig.populate_dispatch_vector(window_dispatch_vector);
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include "kaiser.hpp"
27+
#include "common.hpp"
28+
29+
#include "utils/output_validation.hpp"
30+
#include "utils/type_dispatch.hpp"
31+
#include "utils/type_utils.hpp"
32+
33+
#include <sycl/sycl.hpp>
34+
35+
/**
36+
* Version of SYCL DPC++ 2025.1 compiler where an issue with
37+
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
38+
*/
39+
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
40+
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
41+
#endif
42+
43+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
44+
#include <sycl/ext/intel/math.hpp>
45+
#endif
46+
47+
#include "../kernels/elementwise_functions/i0.hpp"
48+
49+
namespace dpnp::extensions::window
50+
{
51+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
52+
53+
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
54+
char *,
55+
const std::size_t,
56+
const py::object &,
57+
const std::vector<sycl::event> &);
58+
59+
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
60+
61+
template <typename T>
62+
class KaiserFunctor
63+
{
64+
private:
65+
T *data = nullptr;
66+
const std::size_t N;
67+
const T beta;
68+
69+
public:
70+
KaiserFunctor(T *data, const std::size_t N, const T beta)
71+
: data(data), N(N), beta(beta)
72+
{
73+
}
74+
75+
void operator()(sycl::id<1> id) const
76+
{
77+
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT
78+
using sycl::ext::intel::math::cyl_bessel_i0;
79+
#else
80+
using dpnp::kernels::i0::impl::cyl_bessel_i0;
81+
#endif
82+
83+
const auto i = id.get(0);
84+
const T alpha = (N - 1) / T(2);
85+
const T tmp = (i - alpha) / alpha;
86+
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
87+
cyl_bessel_i0(beta);
88+
}
89+
};
90+
91+
template <typename T, template <typename> class Functor>
92+
sycl::event kaiser_impl(sycl::queue &q,
93+
char *result,
94+
const std::size_t nelems,
95+
const py::object &py_beta,
96+
const std::vector<sycl::event> &depends)
97+
{
98+
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
99+
100+
T *res = reinterpret_cast<T *>(result);
101+
const T beta = py::cast<const T>(py_beta);
102+
103+
sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
104+
cgh.depends_on(depends);
105+
106+
using KaiserKernel = Functor<T>;
107+
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
108+
KaiserKernel(res, nelems, beta));
109+
});
110+
111+
return kaiser_ev;
112+
}
113+
114+
template <typename fnT, typename T>
115+
struct KaiserFactory
116+
{
117+
fnT get()
118+
{
119+
if constexpr (std::is_floating_point_v<T>) {
120+
return kaiser_impl<T, KaiserFunctor>;
121+
}
122+
else {
123+
return nullptr;
124+
}
125+
}
126+
};
127+
128+
std::pair<sycl::event, sycl::event>
129+
py_kaiser(sycl::queue &exec_q,
130+
const py::object &py_beta,
131+
const dpctl::tensor::usm_ndarray &result,
132+
const std::vector<sycl::event> &depends)
133+
{
134+
auto [nelems, result_typeless_ptr, fn] =
135+
window_fn<kaiser_fn_ptr_t>(exec_q, result, kaiser_dispatch_vector);
136+
137+
if (nelems == 0) {
138+
return std::make_pair(sycl::event{}, sycl::event{});
139+
}
140+
141+
sycl::event kaiser_ev =
142+
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
143+
sycl::event args_ev =
144+
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});
145+
146+
return std::make_pair(args_ev, kaiser_ev);
147+
}
148+
149+
void init_kaiser_dispatch_vectors()
150+
{
151+
init_window_dispatch_vectors<kaiser_fn_ptr_t, KaiserFactory>(
152+
kaiser_dispatch_vector);
153+
}
154+
155+
} // namespace dpnp::extensions::window
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <dpctl4pybind11.hpp>
29+
#include <sycl/sycl.hpp>
30+
31+
namespace dpnp::extensions::window
32+
{
33+
extern std::pair<sycl::event, sycl::event>
34+
py_kaiser(sycl::queue &exec_q,
35+
const py::object &beta,
36+
const dpctl::tensor::usm_ndarray &result,
37+
const std::vector<sycl::event> &depends);
38+
39+
extern void init_kaiser_dispatch_vectors(void);
40+
41+
} // namespace dpnp::extensions::window

dpnp/backend/extensions/window/window_py.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "common.hpp"
3636
#include "hamming.hpp"
3737
#include "hanning.hpp"
38+
#include "kaiser.hpp"
3839

3940
namespace window_ns = dpnp::extensions::window;
4041
namespace py = pybind11;
@@ -54,7 +55,8 @@ PYBIND11_MODULE(_window_impl, m)
5455

5556
{
5657
window_ns::init_window_dispatch_vectors<
57-
window_ns::kernels::BartlettFactory>(bartlett_dispatch_vector);
58+
window_ns::window_fn_ptr_t, window_ns::kernels::BartlettFactory>(
59+
bartlett_dispatch_vector);
5860

5961
auto bartlett_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
6062
const event_vecT &depends = {}) {
@@ -69,7 +71,8 @@ PYBIND11_MODULE(_window_impl, m)
6971

7072
{
7173
window_ns::init_window_dispatch_vectors<
72-
window_ns::kernels::BlackmanFactory>(blackman_dispatch_vector);
74+
window_ns::window_fn_ptr_t, window_ns::kernels::BlackmanFactory>(
75+
blackman_dispatch_vector);
7376

7477
auto blackman_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
7578
const event_vecT &depends = {}) {
@@ -84,7 +87,8 @@ PYBIND11_MODULE(_window_impl, m)
8487

8588
{
8689
window_ns::init_window_dispatch_vectors<
87-
window_ns::kernels::HammingFactory>(hamming_dispatch_vector);
90+
window_ns::window_fn_ptr_t, window_ns::kernels::HammingFactory>(
91+
hamming_dispatch_vector);
8892

8993
auto hamming_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
9094
const event_vecT &depends = {}) {
@@ -99,7 +103,8 @@ PYBIND11_MODULE(_window_impl, m)
99103

100104
{
101105
window_ns::init_window_dispatch_vectors<
102-
window_ns::kernels::HanningFactory>(hanning_dispatch_vector);
106+
window_ns::window_fn_ptr_t, window_ns::kernels::HanningFactory>(
107+
hanning_dispatch_vector);
103108

104109
auto hanning_pyapi = [&](sycl::queue &exec_q, const arrayT &result,
105110
const event_vecT &depends = {}) {
@@ -111,4 +116,12 @@ PYBIND11_MODULE(_window_impl, m)
111116
py::arg("sycl_queue"), py::arg("result"),
112117
py::arg("depends") = py::list());
113118
}
119+
120+
{
121+
window_ns::init_kaiser_dispatch_vectors();
122+
123+
m.def("_kaiser", window_ns::py_kaiser, "Call Kaiser kernel",
124+
py::arg("sycl_queue"), py::arg("beta"), py::arg("result"),
125+
py::arg("depends") = py::list());
126+
}
114127
}

0 commit comments

Comments
 (0)