@@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
67
67
return window_ev;
68
68
}
69
69
70
- inline std::pair<sycl::event, sycl::event>
71
- py_window (sycl::queue &exec_q,
70
+ template <typename funcPtrT>
71
+ std::tuple<size_t , char *, funcPtrT>
72
+ window_fn (sycl::queue &exec_q,
72
73
const dpctl::tensor::usm_ndarray &result,
73
- const std::vector<sycl::event> &depends,
74
- const window_fn_ptr_t *window_dispatch_vector)
74
+ const funcPtrT *window_dispatch_vector)
75
75
{
76
76
dpctl::tensor::validation::CheckWritable::throw_if_not_writable (result);
77
77
@@ -92,19 +92,35 @@ inline std::pair<sycl::event, sycl::event>
92
92
93
93
size_t nelems = result.get_size ();
94
94
if (nelems == 0 ) {
95
- return std::make_pair (sycl::event{}, sycl::event{} );
95
+ return std::make_tuple (nelems, nullptr , nullptr );
96
96
}
97
97
98
98
int result_typenum = result.get_typenum ();
99
99
auto array_types = dpctl_td_ns::usm_ndarray_types ();
100
100
int result_type_id = array_types.typenum_to_lookup_id (result_typenum);
101
- auto fn = window_dispatch_vector[result_type_id];
101
+ funcPtrT fn = window_dispatch_vector[result_type_id];
102
102
103
103
if (fn == nullptr ) {
104
104
throw std::runtime_error (" Type of given array is not supported" );
105
105
}
106
106
107
107
char *result_typeless_ptr = result.get_data ();
108
+ return std::make_tuple (nelems, result_typeless_ptr, fn);
109
+ }
110
+
111
+ inline std::pair<sycl::event, sycl::event>
112
+ py_window (sycl::queue &exec_q,
113
+ const dpctl::tensor::usm_ndarray &result,
114
+ const std::vector<sycl::event> &depends,
115
+ const window_fn_ptr_t *window_dispatch_vector)
116
+ {
117
+ auto [nelems, result_typeless_ptr, fn] =
118
+ window_fn<window_fn_ptr_t >(exec_q, result, window_dispatch_vector);
119
+
120
+ if (nelems == 0 ) {
121
+ return std::make_pair (sycl::event{}, sycl::event{});
122
+ }
123
+
108
124
sycl::event window_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
109
125
sycl::event args_ev =
110
126
dpctl::utils::keep_args_alive (exec_q, {result}, {window_ev});
0 commit comments