@@ -114,8 +114,54 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
114
114
return to_num_ev;
115
115
}
116
116
117
+ typedef sycl::event (*nan_to_num_contig_fn_ptr_t )(
118
+ sycl::queue &,
119
+ size_t ,
120
+ const py::object &,
121
+ const py::object &,
122
+ const py::object &,
123
+ const char *,
124
+ char *,
125
+ const std::vector<sycl::event> &);
126
+
127
+ template <typename T>
128
+ sycl::event nan_to_num_contig_call (sycl::queue &exec_q,
129
+ size_t nelems,
130
+ const py::object &py_nan,
131
+ const py::object &py_posinf,
132
+ const py::object &py_neginf,
133
+ const char *arg_p,
134
+ char *dst_p,
135
+ const std::vector<sycl::event> &depends)
136
+ {
137
+ sycl::event to_num_contig_ev;
138
+
139
+ using dpctl::tensor::type_utils::is_complex;
140
+ if constexpr (is_complex<T>::value) {
141
+ using realT = typename T::value_type;
142
+ realT nan_v = py::cast<realT>(py_nan);
143
+ realT posinf_v = py::cast<realT>(py_posinf);
144
+ realT neginf_v = py::cast<realT>(py_neginf);
145
+
146
+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147
+ to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148
+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149
+ }
150
+ else {
151
+ T nan_v = py::cast<T>(py_nan);
152
+ T posinf_v = py::cast<T>(py_posinf);
153
+ T neginf_v = py::cast<T>(py_neginf);
154
+
155
+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156
+ to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157
+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158
+ }
159
+ return to_num_contig_ev;
160
+ }
161
+
117
162
namespace td_ns = dpctl::tensor::type_dispatch;
118
163
nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types];
164
+ nan_to_num_contig_fn_ptr_t nan_to_num_contig_dispatch_vector[td_ns::num_types];
119
165
120
166
std::pair<sycl::event, sycl::event>
121
167
py_nan_to_num (const dpctl::tensor::usm_ndarray &src,
@@ -176,6 +222,37 @@ std::pair<sycl::event, sycl::event>
176
222
const char *src_data = src.get_data ();
177
223
char *dst_data = dst.get_data ();
178
224
225
+ // handle contiguous inputs
226
+ bool is_src_c_contig = src.is_c_contiguous ();
227
+ bool is_src_f_contig = src.is_f_contiguous ();
228
+
229
+ bool is_dst_c_contig = dst.is_c_contiguous ();
230
+ bool is_dst_f_contig = dst.is_f_contiguous ();
231
+
232
+ bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
233
+ bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
234
+
235
+ if (both_c_contig || both_f_contig) {
236
+ auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
237
+
238
+ if (contig_fn == nullptr ) {
239
+ throw std::runtime_error (
240
+ " Contiguous implementation is missing for src_typeid=" +
241
+ std::to_string (src_typeid));
242
+ }
243
+
244
+ auto comp_ev = contig_fn (q, nelems, py_nan, py_posinf, py_neginf,
245
+ src_data, dst_data, depends);
246
+ sycl::event ht_ev =
247
+ dpctl::utils::keep_args_alive (q, {src, dst}, {comp_ev});
248
+
249
+ return std::make_pair (ht_ev, comp_ev);
250
+ }
251
+
252
+ // simplify iteration space
253
+ // if 1d with strides 1 - input is contig
254
+ // dispatch to strided
255
+
179
256
auto const &src_strides = src.get_strides_vector ();
180
257
auto const &dst_strides = dst.get_strides_vector ();
181
258
@@ -195,6 +272,30 @@ std::pair<sycl::event, sycl::event>
195
272
simplified_shape, simplified_src_strides, simplified_dst_strides,
196
273
src_offset, dst_offset);
197
274
275
+ if (nd == 1 && simplified_src_strides[0 ] == 1 &&
276
+ simplified_dst_strides[0 ] == 1 ) {
277
+ // Special case of contiguous data
278
+ auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
279
+
280
+ if (contig_fn == nullptr ) {
281
+ throw std::runtime_error (
282
+ " Contiguous implementation is missing for src_typeid=" +
283
+ std::to_string (src_typeid));
284
+ }
285
+
286
+ int src_elem_size = src.get_elemsize ();
287
+ int dst_elem_size = dst.get_elemsize ();
288
+ auto comp_ev =
289
+ contig_fn (q, nelems, py_nan, py_posinf, py_neginf,
290
+ src_data + src_elem_size * src_offset,
291
+ dst_data + dst_elem_size * dst_offset, depends);
292
+
293
+ sycl::event ht_ev =
294
+ dpctl::utils::keep_args_alive (q, {src, dst}, {comp_ev});
295
+
296
+ return std::make_pair (ht_ev, comp_ev);
297
+ }
298
+
198
299
auto fn = nan_to_num_dispatch_vector[src_typeid];
199
300
200
301
if (fn == nullptr ) {
@@ -277,20 +378,41 @@ struct NanToNumFactory
277
378
}
278
379
};
279
380
280
- void populate_nan_to_num_dispatch_vector (void )
381
+ template <typename fnT, typename T>
382
+ struct NanToNumContigFactory
383
+ {
384
+ fnT get ()
385
+ {
386
+ if constexpr (std::is_same_v<typename NanToNumOutputType<T>::value_type,
387
+ void >) {
388
+ return nullptr ;
389
+ }
390
+ else {
391
+ using ::dpnp::extensions::ufunc::impl::nan_to_num_contig_call;
392
+ return nan_to_num_contig_call<T>;
393
+ }
394
+ }
395
+ };
396
+
397
+ void populate_nan_to_num_dispatch_vectors (void )
281
398
{
282
399
using namespace td_ns ;
283
400
284
- DispatchVectorBuilder<nan_to_num_fn_ptr_t , NanToNumFactory, num_types> dvb;
285
- dvb.populate_dispatch_vector (nan_to_num_dispatch_vector);
401
+ DispatchVectorBuilder<nan_to_num_fn_ptr_t , NanToNumFactory, num_types> dvb1;
402
+ dvb1.populate_dispatch_vector (nan_to_num_dispatch_vector);
403
+
404
+ DispatchVectorBuilder<nan_to_num_contig_fn_ptr_t , NanToNumContigFactory,
405
+ num_types>
406
+ dvb2;
407
+ dvb2.populate_dispatch_vector (nan_to_num_contig_dispatch_vector);
286
408
}
287
409
288
410
} // namespace impl
289
411
290
412
void init_nan_to_num (py::module_ m)
291
413
{
292
414
{
293
- impl::populate_nan_to_num_dispatch_vector ();
415
+ impl::populate_nan_to_num_dispatch_vectors ();
294
416
295
417
using impl::py_nan_to_num;
296
418
m.def (" _nan_to_num" , &py_nan_to_num, " " , py::arg (" src" ),
0 commit comments