Skip to content

Commit c7047bf

Browse files
committed
Simplify nan_to_num call logic
Use std::conditional and value_type_of_t struct to avoid constexpr branches with redundant code
1 parent e89acdd commit c7047bf

File tree

1 file changed

+39
-46
lines changed

1 file changed

+39
-46
lines changed

dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ namespace dpnp::extensions::ufunc
6060

6161
namespace impl
6262
{
63+
64+
template <typename T>
65+
struct value_type_of
66+
{
67+
using type = T;
68+
};
69+
70+
template <typename T>
71+
struct value_type_of<std::complex<T>>
72+
{
73+
using type = T;
74+
};
75+
76+
template <typename T>
77+
using value_type_of_t = typename value_type_of<T>::type;
78+
6379
typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &,
6480
int,
6581
size_t,
@@ -87,30 +103,18 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
87103
py::ssize_t dst_offset,
88104
const std::vector<sycl::event> &depends)
89105
{
90-
sycl::event to_num_ev;
91-
92-
using dpctl::tensor::type_utils::is_complex;
93-
if constexpr (is_complex<T>::value) {
94-
using realT = typename T::value_type;
95-
realT nan_v = py::cast<realT>(py_nan);
96-
realT posinf_v = py::cast<realT>(py_posinf);
97-
realT neginf_v = py::cast<realT>(py_neginf);
98-
99-
using dpnp::kernels::nan_to_num::nan_to_num_impl;
100-
to_num_ev = nan_to_num_impl<T, realT>(
101-
exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
102-
arg_offset, dst_p, dst_offset, depends);
103-
}
104-
else {
105-
T nan_v = py::cast<T>(py_nan);
106-
T posinf_v = py::cast<T>(py_posinf);
107-
T neginf_v = py::cast<T>(py_neginf);
108-
109-
using dpnp::kernels::nan_to_num::nan_to_num_impl;
110-
to_num_ev = nan_to_num_impl<T, T>(
111-
exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
112-
arg_offset, dst_p, dst_offset, depends);
113-
}
106+
using dpctl::tensor::type_utils::is_complex_v;
107+
using scT = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
108+
109+
scT nan_v = py::cast<scT>(py_nan);
110+
scT posinf_v = py::cast<scT>(py_posinf);
111+
scT neginf_v = py::cast<scT>(py_neginf);
112+
113+
using dpnp::kernels::nan_to_num::nan_to_num_impl;
114+
sycl::event to_num_ev = nan_to_num_impl<T, scT>(
115+
exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
116+
arg_offset, dst_p, dst_offset, depends);
117+
114118
return to_num_ev;
115119
}
116120

@@ -134,28 +138,17 @@ sycl::event nan_to_num_contig_call(sycl::queue &exec_q,
134138
char *dst_p,
135139
const std::vector<sycl::event> &depends)
136140
{
137-
sycl::event to_num_contig_ev;
138-
139-
using dpctl::tensor::type_utils::is_complex;
140-
if constexpr (is_complex<T>::value) {
141-
using realT = typename T::value_type;
142-
realT nan_v = py::cast<realT>(py_nan);
143-
realT posinf_v = py::cast<realT>(py_posinf);
144-
realT neginf_v = py::cast<realT>(py_neginf);
145-
146-
using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147-
to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148-
exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149-
}
150-
else {
151-
T nan_v = py::cast<T>(py_nan);
152-
T posinf_v = py::cast<T>(py_posinf);
153-
T neginf_v = py::cast<T>(py_neginf);
154-
155-
using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156-
to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157-
exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158-
}
141+
using dpctl::tensor::type_utils::is_complex_v;
142+
using scT = std::conditional_t<is_complex_v<T>, value_type_of_t<T>, T>;
143+
144+
scT nan_v = py::cast<scT>(py_nan);
145+
scT posinf_v = py::cast<scT>(py_posinf);
146+
scT neginf_v = py::cast<scT>(py_neginf);
147+
148+
using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
149+
sycl::event to_num_contig_ev = nan_to_num_contig_impl<T, scT>(
150+
exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
151+
159152
return to_num_contig_ev;
160153
}
161154

0 commit comments

Comments
 (0)