Skip to content

Commit b394d2f

Browse files
committed
reuse window_fn
1 parent 672d0f8 commit b394d2f

File tree

2 files changed

+24
-33
lines changed

2 files changed

+24
-33
lines changed

dpnp/backend/extensions/window/common.hpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
6767
return window_ev;
6868
}
6969

70-
inline std::pair<sycl::event, sycl::event>
71-
py_window(sycl::queue &exec_q,
70+
template <typename funcPtrT>
71+
std::tuple<size_t, char *, funcPtrT>
72+
window_fn(sycl::queue &exec_q,
7273
const dpctl::tensor::usm_ndarray &result,
73-
const std::vector<sycl::event> &depends,
74-
const window_fn_ptr_t *window_dispatch_vector)
74+
const funcPtrT *window_dispatch_vector)
7575
{
7676
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
7777

@@ -92,19 +92,35 @@ inline std::pair<sycl::event, sycl::event>
9292

9393
size_t nelems = result.get_size();
9494
if (nelems == 0) {
95-
return std::make_pair(sycl::event{}, sycl::event{});
95+
return std::make_tuple(nelems, nullptr, nullptr);
9696
}
9797

9898
int result_typenum = result.get_typenum();
9999
auto array_types = dpctl_td_ns::usm_ndarray_types();
100100
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
101-
auto fn = window_dispatch_vector[result_type_id];
101+
funcPtrT fn = window_dispatch_vector[result_type_id];
102102

103103
if (fn == nullptr) {
104104
throw std::runtime_error("Type of given array is not supported");
105105
}
106106

107107
char *result_typeless_ptr = result.get_data();
108+
return std::make_tuple(nelems, result_typeless_ptr, fn);
109+
}
110+
111+
inline std::pair<sycl::event, sycl::event>
112+
py_window(sycl::queue &exec_q,
113+
const dpctl::tensor::usm_ndarray &result,
114+
const std::vector<sycl::event> &depends,
115+
const window_fn_ptr_t *window_dispatch_vector)
116+
{
117+
auto [nelems, result_typeless_ptr, fn] =
118+
window_fn<window_fn_ptr_t>(exec_q, result, window_dispatch_vector);
119+
120+
if (nelems == 0) {
121+
return std::make_pair(sycl::event{}, sycl::event{});
122+
}
123+
108124
sycl::event window_ev = fn(exec_q, result_typeless_ptr, nelems, depends);
109125
sycl::event args_ev =
110126
dpctl::utils::keep_args_alive(exec_q, {result}, {window_ev});

dpnp/backend/extensions/window/kaiser.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -131,38 +131,13 @@ std::pair<sycl::event, sycl::event>
131131
const dpctl::tensor::usm_ndarray &result,
132132
const std::vector<sycl::event> &depends)
133133
{
134-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
134+
auto [nelems, result_typeless_ptr, fn] =
135+
window_fn<kaiser_fn_ptr_t>(exec_q, result, kaiser_dispatch_vector);
135136

136-
int nd = result.get_ndim();
137-
if (nd != 1) {
138-
throw py::value_error("Array should be 1d");
139-
}
140-
141-
if (!dpctl::utils::queues_are_compatible(exec_q, {result.get_queue()})) {
142-
throw py::value_error(
143-
"Execution queue is not compatible with allocation queue.");
144-
}
145-
146-
const bool is_result_c_contig = result.is_c_contiguous();
147-
if (!is_result_c_contig) {
148-
throw py::value_error("The result input array is not c-contiguous.");
149-
}
150-
151-
size_t nelems = result.get_size();
152137
if (nelems == 0) {
153138
return std::make_pair(sycl::event{}, sycl::event{});
154139
}
155140

156-
int result_typenum = result.get_typenum();
157-
auto array_types = dpctl_td_ns::usm_ndarray_types();
158-
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
159-
auto fn = kaiser_dispatch_vector[result_type_id];
160-
161-
if (fn == nullptr) {
162-
throw std::runtime_error("Type of given array is not supported");
163-
}
164-
165-
char *result_typeless_ptr = result.get_data();
166141
sycl::event kaiser_ev =
167142
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
168143
sycl::event args_ev =

0 commit comments

Comments
 (0)