Skip to content

Commit d229156

Browse files
committed
Add contiguous kernel for nan_to_num
1 parent 782afcc commit d229156

File tree

2 files changed

+163
-4
lines changed

2 files changed

+163
-4
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,54 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
114114
return to_num_ev;
115115
}
116116

117+
// Function-pointer type stored in nan_to_num_contig_dispatch_vector.
// The parameter list mirrors nan_to_num_contig_call<T>: a SYCL queue, the
// element count, three Python objects (named py_nan, py_posinf and py_neginf
// at the call site), a read-only source byte pointer, a writable destination
// byte pointer, and the events the submitted work must wait for.
// The returned sycl::event is the one produced by the callee.
typedef sycl::event (*nan_to_num_contig_fn_ptr_t)(
118+
sycl::queue &,
119+
size_t,
120+
const py::object &,
121+
const py::object &,
122+
const py::object &,
123+
const char *,
124+
char *,
125+
const std::vector<sycl::event> &);
126+
127+
// Host-side entry point of the contiguous nan_to_num path for element type T.
//
// Converts the three Python objects (py_nan, py_posinf, py_neginf) to C++
// scalars and forwards them, together with the queue, element count, raw
// source/destination pointers and dependency events, to
// dpnp::kernels::nan_to_num::nan_to_num_contig_impl. The event returned by
// that call is returned to the caller unchanged.
//
// The scalar type depends on T: for complex T the objects are cast to
// T::value_type, otherwise they are cast to T itself.
// NOTE(review): py::cast presumably throws when an object is not convertible
// to the requested scalar type -- confirm against the pybind11 documentation.
template <typename T>
128+
sycl::event nan_to_num_contig_call(sycl::queue &exec_q,
129+
size_t nelems,
130+
const py::object &py_nan,
131+
const py::object &py_posinf,
132+
const py::object &py_neginf,
133+
const char *arg_p,
134+
char *dst_p,
135+
const std::vector<sycl::event> &depends)
136+
{
137+
sycl::event to_num_contig_ev;
138+
139+
using dpctl::tensor::type_utils::is_complex;
140+
// Complex element types take their replacement scalars as the component
// type (T::value_type) rather than as T.
if constexpr (is_complex<T>::value) {
141+
using realT = typename T::value_type;
142+
realT nan_v = py::cast<realT>(py_nan);
143+
realT posinf_v = py::cast<realT>(py_posinf);
144+
realT neginf_v = py::cast<realT>(py_neginf);
145+
146+
using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147+
to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148+
exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149+
}
150+
// Non-complex element types: the replacement scalars have type T.
else {
151+
T nan_v = py::cast<T>(py_nan);
152+
T posinf_v = py::cast<T>(py_posinf);
153+
T neginf_v = py::cast<T>(py_neginf);
154+
155+
using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156+
to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157+
exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158+
}
159+
return to_num_contig_ev;
160+
}
161+
117162
// Short alias for the dpctl type-dispatch namespace used below.
namespace td_ns = dpctl::tensor::type_dispatch;
118163
// Function-pointer tables with td_ns::num_types entries each. py_nan_to_num
// indexes them with src_typeid and throws std::runtime_error when the selected
// entry is nullptr. Both tables are filled by
// populate_nan_to_num_dispatch_vectors(): the first holds the strided
// implementation, the second (added by this commit) the contiguous one.
nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types];
164+
nan_to_num_contig_fn_ptr_t nan_to_num_contig_dispatch_vector[td_ns::num_types];
119165

120166
std::pair<sycl::event, sycl::event>
121167
py_nan_to_num(const dpctl::tensor::usm_ndarray &src,
@@ -176,6 +222,37 @@ std::pair<sycl::event, sycl::event>
176222
const char *src_data = src.get_data();
177223
char *dst_data = dst.get_data();
178224

225+
// handle contiguous inputs
226+
bool is_src_c_contig = src.is_c_contiguous();
227+
bool is_src_f_contig = src.is_f_contiguous();
228+
229+
bool is_dst_c_contig = dst.is_c_contiguous();
230+
bool is_dst_f_contig = dst.is_f_contiguous();
231+
232+
bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
233+
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
234+
235+
if (both_c_contig || both_f_contig) {
236+
auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
237+
238+
if (contig_fn == nullptr) {
239+
throw std::runtime_error(
240+
"Contiguous implementation is missing for src_typeid=" +
241+
std::to_string(src_typeid));
242+
}
243+
244+
auto comp_ev = contig_fn(q, nelems, py_nan, py_posinf, py_neginf,
245+
src_data, dst_data, depends);
246+
sycl::event ht_ev =
247+
dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});
248+
249+
return std::make_pair(ht_ev, comp_ev);
250+
}
251+
252+
// simplify iteration space
253+
// if 1d with strides 1 - input is contig
254+
// dispatch to strided
255+
179256
auto const &src_strides = src.get_strides_vector();
180257
auto const &dst_strides = dst.get_strides_vector();
181258

@@ -195,6 +272,30 @@ std::pair<sycl::event, sycl::event>
195272
simplified_shape, simplified_src_strides, simplified_dst_strides,
196273
src_offset, dst_offset);
197274

275+
if (nd == 1 && simplified_src_strides[0] == 1 &&
276+
simplified_dst_strides[0] == 1) {
277+
// Special case of contiguous data
278+
auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
279+
280+
if (contig_fn == nullptr) {
281+
throw std::runtime_error(
282+
"Contiguous implementation is missing for src_typeid=" +
283+
std::to_string(src_typeid));
284+
}
285+
286+
int src_elem_size = src.get_elemsize();
287+
int dst_elem_size = dst.get_elemsize();
288+
auto comp_ev =
289+
contig_fn(q, nelems, py_nan, py_posinf, py_neginf,
290+
src_data + src_elem_size * src_offset,
291+
dst_data + dst_elem_size * dst_offset, depends);
292+
293+
sycl::event ht_ev =
294+
dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev});
295+
296+
return std::make_pair(ht_ev, comp_ev);
297+
}
298+
198299
auto fn = nan_to_num_dispatch_vector[src_typeid];
199300

200301
if (fn == nullptr) {
@@ -277,20 +378,41 @@ struct NanToNumFactory
277378
}
278379
};
279380

280-
void populate_nan_to_num_dispatch_vector(void)
381+
// Factory passed to DispatchVectorBuilder to produce the entry of the
// contiguous dispatch table for element type T.
//
// get() returns nullptr when NanToNumOutputType<T>::value_type is void, and
// otherwise a pointer to nan_to_num_contig_call<T>. The choice is made at
// compile time, so nan_to_num_contig_call<T> is only instantiated in the
// non-void case.
// NOTE(review): a void value_type presumably marks T as unsupported by
// nan_to_num -- confirm against the definition of NanToNumOutputType.
template <typename fnT, typename T>
382+
struct NanToNumContigFactory
383+
{
384+
fnT get()
385+
{
386+
if constexpr (std::is_same_v<typename NanToNumOutputType<T>::value_type,
387+
void>) {
388+
return nullptr;
389+
}
390+
else {
391+
using ::dpnp::extensions::ufunc::impl::nan_to_num_contig_call;
392+
return nan_to_num_contig_call<T>;
393+
}
394+
}
395+
};
396+
397+
// Fills both nan_to_num dispatch tables: nan_to_num_dispatch_vector via
// NanToNumFactory and nan_to_num_contig_dispatch_vector via
// NanToNumContigFactory. init_nan_to_num calls it before registering the
// Python binding.
//
// NOTE(review): this text is a diff rendering. Statements preceded by a gutter
// marker ending in '-' (the `dvb` pair) are the pre-commit version; those
// preceded by a marker ending in '+' (`dvb1`, `dvb2`) replace them.
void populate_nan_to_num_dispatch_vectors(void)
281398
{
282399
using namespace td_ns;
283400

284-
DispatchVectorBuilder<nan_to_num_fn_ptr_t, NanToNumFactory, num_types> dvb;
285-
dvb.populate_dispatch_vector(nan_to_num_dispatch_vector);
401+
// Strided-implementation table.
DispatchVectorBuilder<nan_to_num_fn_ptr_t, NanToNumFactory, num_types> dvb1;
402+
dvb1.populate_dispatch_vector(nan_to_num_dispatch_vector);
403+
404+
// Contiguous-implementation table.
DispatchVectorBuilder<nan_to_num_contig_fn_ptr_t, NanToNumContigFactory,
405+
num_types>
406+
dvb2;
407+
dvb2.populate_dispatch_vector(nan_to_num_contig_dispatch_vector);
286408
}
287409

288410
} // namespace impl
289411

290412
void init_nan_to_num(py::module_ m)
291413
{
292414
{
293-
impl::populate_nan_to_num_dispatch_vector();
415+
impl::populate_nan_to_num_dispatch_vectors();
294416

295417
using impl::py_nan_to_num;
296418
m.def("_nan_to_num", &py_nan_to_num, "", py::arg("src"),

dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,41 @@ sycl::event nan_to_num_impl(sycl::queue &q,
127127
return comp_ev;
128128
}
129129

130+
// Forward-declared class template that is never defined here; it is used only
// as the kernel name passed to parallel_for in nan_to_num_contig_impl.
template <typename T>
131+
class NanToNumContigKernel;
132+
133+
// Submits the nan_to_num kernel for contiguous input and output.
//
// Template parameters:
//   T    element type of both the input and output buffers
//   scT  type of the three scalar arguments nan, posinf and neginf
// Parameters:
//   q        queue the kernel is submitted to
//   nelems   number of work-items in the 1-D range given to parallel_for
//   nan, posinf, neginf  scalars forwarded to NanToNumFunctor
//   in_cp    source bytes, reinterpreted as const T*
//   out_cp   destination bytes, reinterpreted as T*
//   depends  events the submitted command group waits for
// Returns the event of the submitted command group.
//
// NOTE(review): the kernel name NanToNumContigKernel<T> depends only on T,
// while the functor also depends on scT. The visible caller
// (nan_to_num_contig_call) picks exactly one scT per T, so names stay unique
// there; instantiating the same T with two different scT would reuse a kernel
// name -- confirm whether that can happen elsewhere.
template <typename T, typename scT>
134+
sycl::event nan_to_num_contig_impl(sycl::queue &q,
135+
size_t nelems,
136+
const scT nan,
137+
const scT posinf,
138+
const scT neginf,
139+
const char *in_cp,
140+
char *out_cp,
141+
const std::vector<sycl::event> &depends)
142+
{
143+
// NOTE(review): presumably rejects element types the queue's device cannot
// handle -- confirm against dpctl's validate_type_for_device.
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
144+
145+
const T *in_tp = reinterpret_cast<const T *>(in_cp);
146+
T *out_tp = reinterpret_cast<T *>(out_cp);
147+
148+
// Input and output both use NoOpIndexer, combined into the two-offset
// indexer type the functor is parameterized on.
// NOTE(review): NoOpIndexer presumably maps a linear index to itself, which
// is what makes this the contiguous path -- confirm in dpctl offset_utils.
using dpctl::tensor::offset_utils::NoOpIndexer;
149+
using InOutIndexerT =
150+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<NoOpIndexer,
151+
NoOpIndexer>;
152+
constexpr NoOpIndexer in_indexer{};
153+
constexpr NoOpIndexer out_indexer{};
154+
constexpr InOutIndexerT indexer{in_indexer, out_indexer};
155+
156+
sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
157+
cgh.depends_on(depends);
158+
159+
using KernelName = NanToNumContigKernel<T>;
160+
// One work-item per element; the same functor type as the strided kernel
// is reused, only the indexer type differs.
cgh.parallel_for<KernelName>(
161+
{nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
162+
in_tp, out_tp, indexer, nan, posinf, neginf));
163+
});
164+
return comp_ev;
165+
}
166+
130167
} // namespace dpnp::kernels::nan_to_num

0 commit comments

Comments
 (0)