#include <sycl/sycl.hpp>

// dpctl tensor headers
+ #include "kernels/alignment.hpp"
#include "kernels/dpctl_tensor_types.hpp"
#include "utils/offset_utils.hpp"
+ #include "utils/sycl_utils.hpp"
#include "utils/type_utils.hpp"

namespace dpnp::kernels::nan_to_num
@@ -49,6 +51,14 @@ inline T to_num(const T v, const T nan, const T posinf, const T neginf)
template <typename T, typename scT, typename InOutIndexerT>
struct NanToNumFunctor
{
+ private:
+     const T *inp_ = nullptr;
+     T *out_ = nullptr;
+     const InOutIndexerT inp_out_indexer_;
+     const scT nan_;
+     const scT posinf_;
+     const scT neginf_;
+
public:
    NanToNumFunctor(const T *inp,
                    T *out,
@@ -80,18 +90,104 @@ struct NanToNumFunctor
            out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_);
        }
    }
+ };

+ template <typename T,
+           typename scT,
+           std::uint8_t vec_sz = 4u,
+           std::uint8_t n_vecs = 2u,
+           bool enable_sg_loadstore = true>
+ struct NanToNumContigFunctor
+ {
private:
-     const T *inp_ = nullptr;
+     const T *in_ = nullptr;
    T *out_ = nullptr;
-     const InOutIndexerT inp_out_indexer_;
+     std::size_t nelems_;
    const scT nan_;
    const scT posinf_;
    const scT neginf_;
- };

- template <typename T>
- class NanToNumKernel;
+ public:
+     NanToNumContigFunctor(const T *in,
+                           T *out,
+                           const std::size_t n_elems,
+                           const scT nan,
+                           const scT posinf,
+                           const scT neginf)
+         : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf),
+           neginf_(neginf)
+     {
+     }
+
+     void operator()(sycl::nd_item<1> ndit) const
+     {
+         constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
+         /* Each work-item processes vec_sz elements, contiguous in memory */
+         /* NOTE: work-group size must be divisible by sub-group size */
+
+         using dpctl::tensor::type_utils::is_complex_v;
+         if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
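+             // Fast path: real-typed data is processed with sub-group block
+             // loads/stores; complex types always take the scalar branch below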
+             auto sg = ndit.get_sub_group();
+             const std::uint16_t sgSize = sg.get_max_local_range()[0];
+             const std::size_t base =
+                 elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
+                                 sg.get_group_id()[0] * sgSize);
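+             // `base` is the index of the first element owned by this
+             // sub-group; each sub-group covers elems_per_wi * sgSize
+             // contiguous elements of the flattened array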
+
+             if (base + elems_per_wi * sgSize < nelems_) {
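+                 // full tile: every lane's vec_sz-wide accesses stay in
+                 // bounds, so no per-element bounds checks are needed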
+                 using dpctl::tensor::sycl_utils::sub_group_load;
+                 using dpctl::tensor::sycl_utils::sub_group_store;
+ #pragma unroll
+                 for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
+                     const std::size_t offset = base + it * sgSize;
+                     auto in_multi_ptr = sycl::address_space_cast<
+                         sycl::access::address_space::global_space,
+                         sycl::access::decorated::yes>(&in_[offset]);
+                     auto out_multi_ptr = sycl::address_space_cast<
+                         sycl::access::address_space::global_space,
+                         sycl::access::decorated::yes>(&out_[offset]);
+
+                     sycl::vec<T, vec_sz> arg_vec =
+                         sub_group_load<vec_sz>(sg, in_multi_ptr);
+ #pragma unroll
+                     for (std::uint32_t k = 0; k < vec_sz; ++k) {
+                         arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_);
+                     }
+                     sub_group_store<vec_sz>(sg, arg_vec, out_multi_ptr);
+                 }
+             }
+             else {
+                 const std::size_t lane_id = sg.get_local_id()[0];
+                 for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
+                     out_[k] = to_num(in_[k], nan_, posinf_, neginf_);
+                 }
+             }
+         }
+         else {
+             const std::uint16_t sgSize =
+                 ndit.get_sub_group().get_local_range()[0];
+             const std::size_t gid = ndit.get_global_linear_id();
+             const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
+
+             const std::size_t start =
+                 (gid / sgSize) * (elems_per_sg - sgSize) + gid;
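+             // equivalent to: sub_group_id * elems_per_sg + lane_id, so each
+             // sub-group owns elems_per_sg consecutive elements and its lanes
+             // sweep them with stride sgSize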
+             const std::size_t end = std::min(nelems_, start + elems_per_sg);
+             for (std::size_t offset = start; offset < end; offset += sgSize) {
+                 if constexpr (is_complex_v<T>) {
+                     using realT = typename T::value_type;
+                     static_assert(std::is_same_v<realT, scT>);
+
+                     T z = in_[offset];
+                     realT x = to_num(z.real(), nan_, posinf_, neginf_);
+                     realT y = to_num(z.imag(), nan_, posinf_, neginf_);
+                     out_[offset] = T{x, y};
+                 }
+                 else {
+                     out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_);
+                 }
+             }
+         }
+     }
+ };

template <typename T, typename scT>
sycl::event nan_to_num_impl(sycl::queue &q,
@@ -119,48 +215,69 @@ sycl::event nan_to_num_impl(sycl::queue &q,
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-         using KernelName = NanToNumKernel<T>;
-         cgh.parallel_for<KernelName>(
-             {nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
-                           in_tp, out_tp, indexer, nan, posinf, neginf));
+         using NanToNumFunc = NanToNumFunctor<T, scT, InOutIndexerT>;
+         cgh.parallel_for<NanToNumFunc>(
+             {nelems},
+             NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf));
    });
    return comp_ev;
}

- template <typename T>
- class NanToNumContigKernel;
-
- template <typename T, typename scT>
- sycl::event nan_to_num_contig_impl(sycl::queue &q,
-                                    const size_t nelems,
+ template <typename T,
+           typename scT,
+           std::uint8_t vec_sz = 4u,
+           std::uint8_t n_vecs = 2u>
+ sycl::event nan_to_num_contig_impl(sycl::queue &exec_q,
+                                    std::size_t nelems,
                                     const scT nan,
                                     const scT posinf,
                                     const scT neginf,
                                     const char *in_cp,
                                     char *out_cp,
-                                    const std::vector<sycl::event> &depends)
+                                    const std::vector<sycl::event> &depends = {})
{
-     dpctl::tensor::type_utils::validate_type_for_device<T>(q);
+     constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
+     const std::size_t n_work_items_needed = nelems / elems_per_wi;
+     const std::size_t empirical_threshold = std::size_t(1) << 21;
+     const std::size_t lws = (n_work_items_needed <= empirical_threshold)
+                                 ? std::size_t(128)
+                                 : std::size_t(256);
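+     // heuristic: a smaller work-group size for modest inputs, a larger one
+     // once there is enough work to keep the device fully occupied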
+
+     const std::size_t n_groups =
+         ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
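+     // ceil-division: each group handles lws * elems_per_wi elements, so
+     // n_groups groups are enough to cover all nelems elements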
+     const auto gws_range = sycl::range<1>(n_groups * lws);
+     const auto lws_range = sycl::range<1>(lws);

    const T *in_tp = reinterpret_cast<const T *>(in_cp);
    T *out_tp = reinterpret_cast<T *>(out_cp);

-     using dpctl::tensor::offset_utils::NoOpIndexer;
-     using InOutIndexerT =
-         dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<NoOpIndexer,
-                                                                 NoOpIndexer>;
-     constexpr NoOpIndexer in_indexer{};
-     constexpr NoOpIndexer out_indexer{};
-     constexpr InOutIndexerT indexer{in_indexer, out_indexer};
-
-     sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
+     sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

-         using KernelName = NanToNumContigKernel<T>;
-         cgh.parallel_for<KernelName>(
-             {nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
-                           in_tp, out_tp, indexer, nan, posinf, neginf));
+         using dpctl::tensor::kernels::alignment_utils::is_aligned;
+         using dpctl::tensor::kernels::alignment_utils::required_alignment;
+         if (is_aligned<required_alignment>(in_tp) &&
+             is_aligned<required_alignment>(out_tp))
+         {
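+             // aligned pointers permit the sub-group block load/store
+             // specialization of the contiguous functor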
+             constexpr bool enable_sg_loadstore = true;
+             using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
+                                                        enable_sg_loadstore>;
+
+             cgh.parallel_for<NanToNumFunc>(
+                 sycl::nd_range<1>(gws_range, lws_range),
+                 NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
+         }
+         else {
+             constexpr bool disable_sg_loadstore = false;
+             using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
+                                                        disable_sg_loadstore>;
+
+             cgh.parallel_for<NanToNumFunc>(
+                 sycl::nd_range<1>(gws_range, lws_range),
+                 NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
+         }
    });
+
    return comp_ev;
}
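For context, here is a minimal host-side sketch of how the new contiguous entry point might be driven. The queue setup, USM allocation, float specialization, and replacement values below are illustrative assumptions, not part of this change; the values mirror numpy.nan_to_num defaults (NaN to 0, +/-inf to the largest finite floats), and the default vec_sz/n_vecs template parameters are used:

// Hypothetical driver for nan_to_num_contig_impl; assumes float data.
#include <cstddef>
#include <limits>
#include <sycl/sycl.hpp>

int main()
{
    sycl::queue q;
    constexpr std::size_t n = std::size_t(1) << 20;

    float *in = sycl::malloc_device<float>(n, q);
    float *out = sycl::malloc_device<float>(n, q);
    // ... fill `in`, possibly with NaN/inf values ...

    // T = float, scT = float; the impl picks the aligned (sub-group
    // load/store) or unaligned path internally.
    sycl::event ev =
        dpnp::kernels::nan_to_num::nan_to_num_contig_impl<float, float>(
            q, n, /*nan=*/0.0f,
            /*posinf=*/std::numeric_limits<float>::max(),
            /*neginf=*/std::numeric_limits<float>::lowest(),
            reinterpret_cast<const char *>(in),
            reinterpret_cast<char *>(out));
    ev.wait();

    sycl::free(in, q);
    sycl::free(out, q);
    return 0;
}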