Skip to content

Commit 5d19afb

Browse files
committed
Add subgroup load and store based implementation for nan_to_num kernel
1 parent 54dfaf5 commit 5d19afb

File tree

1 file changed

+147
-30
lines changed


dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp

Lines changed: 147 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131

3232
#include <sycl/sycl.hpp>
3333
// dpctl tensor headers
34+
#include "kernels/alignment.hpp"
3435
#include "kernels/dpctl_tensor_types.hpp"
3536
#include "utils/offset_utils.hpp"
37+
#include "utils/sycl_utils.hpp"
3638
#include "utils/type_utils.hpp"
3739

3840
namespace dpnp::kernels::nan_to_num
@@ -49,6 +51,14 @@ inline T to_num(const T v, const T nan, const T posinf, const T neginf)
4951
template <typename T, typename scT, typename InOutIndexerT>
5052
struct NanToNumFunctor
5153
{
54+
private:
55+
const T *inp_ = nullptr;
56+
T *out_ = nullptr;
57+
const InOutIndexerT inp_out_indexer_;
58+
const scT nan_;
59+
const scT posinf_;
60+
const scT neginf_;
61+
5262
public:
5363
NanToNumFunctor(const T *inp,
5464
T *out,
@@ -80,18 +90,104 @@ struct NanToNumFunctor
8090
out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_);
8191
}
8292
}
93+
};
8394

95+
template <typename T,
96+
typename scT,
97+
std::uint8_t vec_sz = 4u,
98+
std::uint8_t n_vecs = 2u,
99+
bool enable_sg_loadstore = true>
100+
struct NanToNumContigFunctor
101+
{
84102
private:
85-
const T *inp_ = nullptr;
103+
const T *in_ = nullptr;
86104
T *out_ = nullptr;
87-
const InOutIndexerT inp_out_indexer_;
105+
std::size_t nelems_;
88106
const scT nan_;
89107
const scT posinf_;
90108
const scT neginf_;
91-
};
92109

93-
template <typename T>
94-
class NanToNumKernel;
110+
public:
111+
NanToNumContigFunctor(const T *in,
112+
T *out,
113+
const std::size_t n_elems,
114+
const scT nan,
115+
const scT posinf,
116+
const scT neginf)
117+
: in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf),
118+
neginf_(neginf)
119+
{
120+
}
121+
122+
void operator()(sycl::nd_item<1> ndit) const
123+
{
124+
constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
125+
/* Each work-item processes vec_sz elements, contiguous in memory */
126+
/* NOTE: work-group size must be divisible by sub-group size */
127+
128+
using dpctl::tensor::type_utils::is_complex_v;
129+
if constexpr (enable_sg_loadstore && !is_complex_v<T>) {
130+
auto sg = ndit.get_sub_group();
131+
const std::uint16_t sgSize = sg.get_max_local_range()[0];
132+
const std::size_t base =
133+
elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) +
134+
sg.get_group_id()[0] * sgSize);
135+
136+
if (base + elems_per_wi * sgSize < nelems_) {
137+
using dpctl::tensor::sycl_utils::sub_group_load;
138+
using dpctl::tensor::sycl_utils::sub_group_store;
139+
#pragma unroll
140+
for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) {
141+
const std::size_t offset = base + it * sgSize;
142+
auto in_multi_ptr = sycl::address_space_cast<
143+
sycl::access::address_space::global_space,
144+
sycl::access::decorated::yes>(&in_[offset]);
145+
auto out_multi_ptr = sycl::address_space_cast<
146+
sycl::access::address_space::global_space,
147+
sycl::access::decorated::yes>(&out_[offset]);
148+
149+
sycl::vec<T, vec_sz> arg_vec =
150+
sub_group_load<vec_sz>(sg, in_multi_ptr);
151+
#pragma unroll
152+
for (std::uint32_t k = 0; k < vec_sz; ++k) {
153+
arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_);
154+
}
155+
sub_group_store<vec_sz>(sg, arg_vec, out_multi_ptr);
156+
}
157+
}
158+
else {
159+
const std::size_t lane_id = sg.get_local_id()[0];
160+
for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) {
161+
out_[k] = to_num(in_[k], nan_, posinf_, neginf_);
162+
}
163+
}
164+
}
165+
else {
166+
const std::uint16_t sgSize =
167+
ndit.get_sub_group().get_local_range()[0];
168+
const std::size_t gid = ndit.get_global_linear_id();
169+
const std::uint16_t elems_per_sg = sgSize * elems_per_wi;
170+
171+
const std::size_t start =
172+
(gid / sgSize) * (elems_per_sg - sgSize) + gid;
173+
const std::size_t end = std::min(nelems_, start + elems_per_sg);
174+
for (std::size_t offset = start; offset < end; offset += sgSize) {
175+
if constexpr (is_complex_v<T>) {
176+
using realT = typename T::value_type;
177+
static_assert(std::is_same_v<realT, scT>);
178+
179+
T z = in_[offset];
180+
realT x = to_num(z.real(), nan_, posinf_, neginf_);
181+
realT y = to_num(z.imag(), nan_, posinf_, neginf_);
182+
out_[offset] = T{x, y};
183+
}
184+
else {
185+
out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_);
186+
}
187+
}
188+
}
189+
}
190+
};
95191

96192
template <typename T, typename scT>
97193
sycl::event nan_to_num_impl(sycl::queue &q,
@@ -119,48 +215,69 @@ sycl::event nan_to_num_impl(sycl::queue &q,
119215
sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
120216
cgh.depends_on(depends);
121217

122-
using KernelName = NanToNumKernel<T>;
123-
cgh.parallel_for<KernelName>(
124-
{nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
125-
in_tp, out_tp, indexer, nan, posinf, neginf));
218+
using NanToNumFunc = NanToNumFunctor<T, scT, InOutIndexerT>;
219+
cgh.parallel_for<NanToNumFunc>(
220+
{nelems},
221+
NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf));
126222
});
127223
return comp_ev;
128224
}
129225

130-
template <typename T>
131-
class NanToNumContigKernel;
132-
133-
template <typename T, typename scT>
134-
sycl::event nan_to_num_contig_impl(sycl::queue &q,
135-
const size_t nelems,
226+
template <typename T,
227+
typename scT,
228+
std::uint8_t vec_sz = 4u,
229+
std::uint8_t n_vecs = 2u>
230+
sycl::event nan_to_num_contig_impl(sycl::queue &exec_q,
231+
std::size_t nelems,
136232
const scT nan,
137233
const scT posinf,
138234
const scT neginf,
139235
const char *in_cp,
140236
char *out_cp,
141-
const std::vector<sycl::event> &depends)
237+
const std::vector<sycl::event> &depends = {})
142238
{
143-
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
239+
constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz;
240+
const std::size_t n_work_items_needed = nelems / elems_per_wi;
241+
const std::size_t empirical_threshold = std::size_t(1) << 21;
242+
const std::size_t lws = (n_work_items_needed <= empirical_threshold)
243+
? std::size_t(128)
244+
: std::size_t(256);
245+
246+
const std::size_t n_groups =
247+
((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi));
248+
const auto gws_range = sycl::range<1>(n_groups * lws);
249+
const auto lws_range = sycl::range<1>(lws);
144250

145251
const T *in_tp = reinterpret_cast<const T *>(in_cp);
146252
T *out_tp = reinterpret_cast<T *>(out_cp);
147253

148-
using dpctl::tensor::offset_utils::NoOpIndexer;
149-
using InOutIndexerT =
150-
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<NoOpIndexer,
151-
NoOpIndexer>;
152-
constexpr NoOpIndexer in_indexer{};
153-
constexpr NoOpIndexer out_indexer{};
154-
constexpr InOutIndexerT indexer{in_indexer, out_indexer};
155-
156-
sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
254+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
157255
cgh.depends_on(depends);
158256

159-
using KernelName = NanToNumContigKernel<T>;
160-
cgh.parallel_for<KernelName>(
161-
{nelems}, NanToNumFunctor<T, scT, InOutIndexerT>(
162-
in_tp, out_tp, indexer, nan, posinf, neginf));
257+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
258+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
259+
if (is_aligned<required_alignment>(in_tp) &&
260+
is_aligned<required_alignment>(out_tp))
261+
{
262+
constexpr bool enable_sg_loadstore = true;
263+
using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
264+
enable_sg_loadstore>;
265+
266+
cgh.parallel_for<NanToNumFunc>(
267+
sycl::nd_range<1>(gws_range, lws_range),
268+
NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
269+
}
270+
else {
271+
constexpr bool disable_sg_loadstore = false;
272+
using NanToNumFunc = NanToNumContigFunctor<T, scT, vec_sz, n_vecs,
273+
disable_sg_loadstore>;
274+
275+
cgh.parallel_for<NanToNumFunc>(
276+
sycl::nd_range<1>(gws_range, lws_range),
277+
NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf));
278+
}
163279
});
280+
164281
return comp_ev;
165282
}
166283

0 commit comments

Comments
 (0)