Skip to content

Commit 6fe1e2b

Browse files
committed
pass beta as python object
1 parent 24c10d9 commit 6fe1e2b

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

dpnp/backend/extensions/window/kaiser.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5050
typedef sycl::event (*kaiser_fn_ptr_t)(sycl::queue &,
5151
char *,
5252
const std::size_t,
53-
const float,
53+
const py::object &,
5454
const std::vector<sycl::event> &);
5555

5656
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
@@ -61,10 +61,10 @@ class KaiserFunctor
6161
private:
6262
T *data = nullptr;
6363
const std::size_t N;
64-
const float beta;
64+
const T beta;
6565

6666
public:
67-
KaiserFunctor(T *data, const std::size_t N, const float beta)
67+
KaiserFunctor(T *data, const std::size_t N, const T beta)
6868
: data(data), N(N), beta(beta)
6969
{
7070
}
@@ -89,12 +89,13 @@ template <typename T, template <typename> class Functor>
8989
sycl::event kaiser_impl(sycl::queue &q,
9090
char *result,
9191
const std::size_t nelems,
92-
const float beta,
92+
const py::object &py_beta,
9393
const std::vector<sycl::event> &depends)
9494
{
9595
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
9696

9797
T *res = reinterpret_cast<T *>(result);
98+
const T beta = py::cast<const T>(py_beta);
9899

99100
sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
100101
cgh.depends_on(depends);
@@ -123,7 +124,7 @@ struct KaiserFactory
123124

124125
std::pair<sycl::event, sycl::event>
125126
py_kaiser(sycl::queue &exec_q,
126-
const float beta,
127+
const py::object &py_beta,
127128
const dpctl::tensor::usm_ndarray &result,
128129
const std::vector<sycl::event> &depends)
129130
{
@@ -160,7 +161,7 @@ std::pair<sycl::event, sycl::event>
160161

161162
char *result_typeless_ptr = result.get_data();
162163
sycl::event kaiser_ev =
163-
fn(exec_q, result_typeless_ptr, nelems, beta, depends);
164+
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
164165
sycl::event args_ev =
165166
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});
166167

dpnp/backend/extensions/window/kaiser.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace dpnp::extensions::window
3232
{
3333
extern std::pair<sycl::event, sycl::event>
3434
py_kaiser(sycl::queue &exec_q,
35-
const float beta,
35+
const py::object &beta,
3636
const dpctl::tensor::usm_ndarray &result,
3737
const std::vector<sycl::event> &depends);
3838

0 commit comments

Comments
 (0)