@@ -60,6 +60,22 @@ namespace dpnp::extensions::ufunc
60
60
61
61
namespace impl
62
62
{
63
+
64
+ template <typename T>
65
+ struct value_type_of
66
+ {
67
+ using type = T;
68
+ };
69
+
70
+ template <typename T>
71
+ struct value_type_of <std::complex<T>>
72
+ {
73
+ using type = T;
74
+ };
75
+
76
+ template <typename T>
77
+ using value_type_of_t = typename value_type_of<T>::type;
78
+
63
79
typedef sycl::event (*nan_to_num_fn_ptr_t )(sycl::queue &,
64
80
int ,
65
81
size_t ,
@@ -87,30 +103,18 @@ sycl::event nan_to_num_call(sycl::queue &exec_q,
87
103
py::ssize_t dst_offset,
88
104
const std::vector<sycl::event> &depends)
89
105
{
90
- sycl::event to_num_ev;
91
-
92
- using dpctl::tensor::type_utils::is_complex;
93
- if constexpr (is_complex<T>::value) {
94
- using realT = typename T::value_type;
95
- realT nan_v = py::cast<realT>(py_nan);
96
- realT posinf_v = py::cast<realT>(py_posinf);
97
- realT neginf_v = py::cast<realT>(py_neginf);
98
-
99
- using dpnp::kernels::nan_to_num::nan_to_num_impl;
100
- to_num_ev = nan_to_num_impl<T, realT>(
101
- exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
102
- arg_offset, dst_p, dst_offset, depends);
103
- }
104
- else {
105
- T nan_v = py::cast<T>(py_nan);
106
- T posinf_v = py::cast<T>(py_posinf);
107
- T neginf_v = py::cast<T>(py_neginf);
108
-
109
- using dpnp::kernels::nan_to_num::nan_to_num_impl;
110
- to_num_ev = nan_to_num_impl<T, T>(
111
- exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
112
- arg_offset, dst_p, dst_offset, depends);
113
- }
106
+ using dpctl::tensor::type_utils::is_complex_v;
107
+ using scT = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
108
+
109
+ scT nan_v = py::cast<scT>(py_nan);
110
+ scT posinf_v = py::cast<scT>(py_posinf);
111
+ scT neginf_v = py::cast<scT>(py_neginf);
112
+
113
+ using dpnp::kernels::nan_to_num::nan_to_num_impl;
114
+ sycl::event to_num_ev = nan_to_num_impl<T, scT>(
115
+ exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p,
116
+ arg_offset, dst_p, dst_offset, depends);
117
+
114
118
return to_num_ev;
115
119
}
116
120
@@ -134,28 +138,17 @@ sycl::event nan_to_num_contig_call(sycl::queue &exec_q,
134
138
char *dst_p,
135
139
const std::vector<sycl::event> &depends)
136
140
{
137
- sycl::event to_num_contig_ev;
138
-
139
- using dpctl::tensor::type_utils::is_complex;
140
- if constexpr (is_complex<T>::value) {
141
- using realT = typename T::value_type;
142
- realT nan_v = py::cast<realT>(py_nan);
143
- realT posinf_v = py::cast<realT>(py_posinf);
144
- realT neginf_v = py::cast<realT>(py_neginf);
145
-
146
- using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
147
- to_num_contig_ev = nan_to_num_contig_impl<T, realT>(
148
- exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
149
- }
150
- else {
151
- T nan_v = py::cast<T>(py_nan);
152
- T posinf_v = py::cast<T>(py_posinf);
153
- T neginf_v = py::cast<T>(py_neginf);
154
-
155
- using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
156
- to_num_contig_ev = nan_to_num_contig_impl<T, T>(
157
- exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
158
- }
141
+ using dpctl::tensor::type_utils::is_complex_v;
142
+ using scT = std::conditional_t <is_complex_v<T>, value_type_of_t <T>, T>;
143
+
144
+ scT nan_v = py::cast<scT>(py_nan);
145
+ scT posinf_v = py::cast<scT>(py_posinf);
146
+ scT neginf_v = py::cast<scT>(py_neginf);
147
+
148
+ using dpnp::kernels::nan_to_num::nan_to_num_contig_impl;
149
+ sycl::event to_num_contig_ev = nan_to_num_contig_impl<T, scT>(
150
+ exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends);
151
+
159
152
return to_num_contig_ev;
160
153
}
161
154
0 commit comments