@@ -50,7 +50,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
50
50
typedef sycl::event (*kaiser_fn_ptr_t )(sycl::queue &,
51
51
char *,
52
52
const std::size_t ,
53
- const float ,
53
+ const py::object & ,
54
54
const std::vector<sycl::event> &);
55
55
56
56
static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
@@ -61,10 +61,10 @@ class KaiserFunctor
61
61
private:
62
62
T *data = nullptr ;
63
63
const std::size_t N;
64
- const float beta;
64
+ const T beta;
65
65
66
66
public:
67
- KaiserFunctor (T *data, const std::size_t N, const float beta)
67
+ KaiserFunctor (T *data, const std::size_t N, const T beta)
68
68
: data(data), N(N), beta(beta)
69
69
{
70
70
}
@@ -89,12 +89,13 @@ template <typename T, template <typename> class Functor>
89
89
sycl::event kaiser_impl (sycl::queue &q,
90
90
char *result,
91
91
const std::size_t nelems,
92
- const float beta ,
92
+ const py::object &py_beta ,
93
93
const std::vector<sycl::event> &depends)
94
94
{
95
95
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
96
96
97
97
T *res = reinterpret_cast <T *>(result);
98
+ const T beta = py::cast<const T>(py_beta);
98
99
99
100
sycl::event kaiser_ev = q.submit ([&](sycl::handler &cgh) {
100
101
cgh.depends_on (depends);
@@ -123,7 +124,7 @@ struct KaiserFactory
123
124
124
125
std::pair<sycl::event, sycl::event>
125
126
py_kaiser (sycl::queue &exec_q,
126
- const float beta ,
127
+ const py::object &py_beta ,
127
128
const dpctl::tensor::usm_ndarray &result,
128
129
const std::vector<sycl::event> &depends)
129
130
{
@@ -160,7 +161,7 @@ std::pair<sycl::event, sycl::event>
160
161
161
162
char *result_typeless_ptr = result.get_data ();
162
163
sycl::event kaiser_ev =
163
- fn (exec_q, result_typeless_ptr, nelems, beta , depends);
164
+ fn (exec_q, result_typeless_ptr, nelems, py_beta , depends);
164
165
sycl::event args_ev =
165
166
dpctl::utils::keep_args_alive (exec_q, {result}, {kaiser_ev});
166
167
0 commit comments