Skip to content

Commit f623209

Browse files
Clean-ups in binary/unary contig call operator
1 parent 9240d5b commit f623209

File tree

1 file changed

+50
-50
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+50
-50
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 50 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ struct UnaryContigFunctor
7070

7171
void operator()(sycl::nd_item<1> ndit) const
7272
{
73+
constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
7374
UnaryOperatorT op{};
7475
/* Each work-item processes vec_sz elements, contiguous in memory */
7576
/* NOTE: work-group size must be divisible by sub-group size */
@@ -80,14 +81,15 @@ struct UnaryContigFunctor
8081
constexpr resT const_val = UnaryOperatorT::constant_value;
8182

8283
auto sg = ndit.get_sub_group();
83-
std::uint8_t sgSize = sg.get_max_local_range()[0];
84-
size_t base = n_vecs * vec_sz *
84+
std::uint32_t sgSize = sg.get_max_local_range()[0];
85+
86+
size_t base = static_cast<size_t>(elems_per_wi) *
8587
(ndit.get_group(0) * ndit.get_local_range(0) +
8688
sg.get_group_id()[0] * sgSize);
87-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
89+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
8890
sycl::vec<resT, vec_sz> res_vec(const_val);
8991
#pragma unroll
90-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
92+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
9193
size_t offset = base + static_cast<size_t>(it) *
9294
static_cast<size_t>(sgSize);
9395
auto out_multi_ptr = sycl::address_space_cast<
@@ -98,9 +100,8 @@ struct UnaryContigFunctor
98100
}
99101
}
100102
else {
101-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
102-
k += sgSize)
103-
{
103+
const size_t lane_id = sg.get_local_id()[0];
104+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
104105
out[k] = const_val;
105106
}
106107
}
@@ -110,15 +111,16 @@ struct UnaryContigFunctor
110111
UnaryOperatorT::supports_vec::value)
111112
{
112113
auto sg = ndit.get_sub_group();
113-
std::uint16_t sgSize = sg.get_max_local_range()[0];
114-
size_t base = n_vecs * vec_sz *
114+
std::uint32_t sgSize = sg.get_max_local_range()[0];
115+
116+
size_t base = static_cast<size_t>(elems_per_wi) *
115117
(ndit.get_group(0) * ndit.get_local_range(0) +
116118
sg.get_group_id()[0] * sgSize);
117-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
119+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
118120
sycl::vec<argT, vec_sz> x;
119121

120122
#pragma unroll
121-
for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
123+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
122124
size_t offset = base + static_cast<size_t>(it) *
123125
static_cast<size_t>(sgSize);
124126
auto in_multi_ptr = sycl::address_space_cast<
@@ -134,9 +136,8 @@ struct UnaryContigFunctor
134136
}
135137
}
136138
else {
137-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
138-
k += sgSize)
139-
{
139+
const size_t lane_id = sg.get_local_id()[0];
140+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
140141
// scalar call
141142
out[k] = op(in[k]);
142143
}
@@ -149,16 +150,16 @@ struct UnaryContigFunctor
149150
// default: use scalar-value function
150151

151152
auto sg = ndit.get_sub_group();
152-
std::uint8_t sgSize = sg.get_max_local_range()[0];
153-
size_t base = n_vecs * vec_sz *
153+
std::uint32_t sgSize = sg.get_max_local_range()[0];
154+
size_t base = static_cast<size_t>(elems_per_wi) *
154155
(ndit.get_group(0) * ndit.get_local_range(0) +
155156
sg.get_group_id()[0] * sgSize);
156157

157-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
158+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
158159
sycl::vec<argT, vec_sz> arg_vec;
159160

160161
#pragma unroll
161-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
162+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
162163
size_t offset = base + static_cast<size_t>(it) *
163164
static_cast<size_t>(sgSize);
164165
auto in_multi_ptr = sycl::address_space_cast<
@@ -170,16 +171,15 @@ struct UnaryContigFunctor
170171

171172
arg_vec = sg.load<vec_sz>(in_multi_ptr);
172173
#pragma unroll
173-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
174+
for (std::uint32_t k = 0; k < vec_sz; ++k) {
174175
arg_vec[k] = op(arg_vec[k]);
175176
}
176177
sg.store<vec_sz>(out_multi_ptr, arg_vec);
177178
}
178179
}
179180
else {
180-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
181-
k += sgSize)
182-
{
181+
const size_t lane_id = sg.get_local_id()[0];
182+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
183183
out[k] = op(in[k]);
184184
}
185185
}
@@ -190,17 +190,17 @@ struct UnaryContigFunctor
190190
// default: use scalar-value function
191191

192192
auto sg = ndit.get_sub_group();
193-
std::uint8_t sgSize = sg.get_max_local_range()[0];
194-
size_t base = n_vecs * vec_sz *
193+
std::uint32_t sgSize = sg.get_max_local_range()[0];
194+
size_t base = static_cast<size_t>(elems_per_wi) *
195195
(ndit.get_group(0) * ndit.get_local_range(0) +
196196
sg.get_group_id()[0] * sgSize);
197197

198-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
198+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
199199
sycl::vec<argT, vec_sz> arg_vec;
200200
sycl::vec<resT, vec_sz> res_vec;
201201

202202
#pragma unroll
203-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
203+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
204204
size_t offset = base + static_cast<size_t>(it) *
205205
static_cast<size_t>(sgSize);
206206
auto in_multi_ptr = sycl::address_space_cast<
@@ -212,27 +212,27 @@ struct UnaryContigFunctor
212212

213213
arg_vec = sg.load<vec_sz>(in_multi_ptr);
214214
#pragma unroll
215-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
215+
for (std::uint32_t k = 0; k < vec_sz; ++k) {
216216
res_vec[k] = op(arg_vec[k]);
217217
}
218218
sg.store<vec_sz>(out_multi_ptr, res_vec);
219219
}
220220
}
221221
else {
222-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
223-
k += sgSize)
224-
{
222+
const size_t lane_id = sg.get_local_id()[0];
223+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
225224
out[k] = op(in[k]);
226225
}
227226
}
228227
}
229228
else {
230-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
229+
size_t sgSize = ndit.get_sub_group().get_local_range()[0];
231230
size_t base = ndit.get_global_linear_id();
231+
const size_t elems_per_sg = sgSize * elems_per_wi;
232232

233-
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
233+
base = (base / sgSize) * elems_per_sg + (base % sgSize);
234234
for (size_t offset = base;
235-
offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
235+
offset < std::min(nelems_, base + elems_per_sg);
236236
offset += sgSize)
237237
{
238238
out[offset] = op(in[offset]);
@@ -392,6 +392,7 @@ struct BinaryContigFunctor
392392

393393
void operator()(sycl::nd_item<1> ndit) const
394394
{
395+
constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
395396
BinaryOperatorT op{};
396397
/* Each work-item processes vec_sz elements, contiguous in memory */
397398
/* NOTE: work-group size must be divisible by sub-group size */
@@ -401,19 +402,19 @@ struct BinaryContigFunctor
401402
BinaryOperatorT::supports_vec::value)
402403
{
403404
auto sg = ndit.get_sub_group();
404-
std::uint8_t sgSize = sg.get_max_local_range()[0];
405+
std::uint16_t sgSize = sg.get_max_local_range()[0];
405406

406-
size_t base = n_vecs * vec_sz *
407+
size_t base = static_cast<size_t>(elems_per_wi) *
407408
(ndit.get_group(0) * ndit.get_local_range(0) +
408409
sg.get_group_id()[0] * sgSize);
409410

410-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
411+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
411412
sycl::vec<argT1, vec_sz> arg1_vec;
412413
sycl::vec<argT2, vec_sz> arg2_vec;
413414
sycl::vec<resT, vec_sz> res_vec;
414415

415416
#pragma unroll
416-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
417+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
417418
size_t offset = base + static_cast<size_t>(it) *
418419
static_cast<size_t>(sgSize);
419420
auto in1_multi_ptr = sycl::address_space_cast<
@@ -433,9 +434,8 @@ struct BinaryContigFunctor
433434
}
434435
}
435436
else {
436-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
437-
k += sgSize)
438-
{
437+
const std::size_t lane_id = sg.get_local_id()[0];
438+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
439439
out[k] = op(in1[k], in2[k]);
440440
}
441441
}
@@ -446,17 +446,17 @@ struct BinaryContigFunctor
446446
auto sg = ndit.get_sub_group();
447447
std::uint8_t sgSize = sg.get_max_local_range()[0];
448448

449-
size_t base = n_vecs * vec_sz *
449+
size_t base = static_cast<size_t>(elems_per_wi) *
450450
(ndit.get_group(0) * ndit.get_local_range(0) +
451451
sg.get_group_id()[0] * sgSize);
452452

453-
if (base + n_vecs * vec_sz * sgSize < nelems_) {
453+
if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
454454
sycl::vec<argT1, vec_sz> arg1_vec;
455455
sycl::vec<argT2, vec_sz> arg2_vec;
456456
sycl::vec<resT, vec_sz> res_vec;
457457

458458
#pragma unroll
459-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
459+
for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
460460
size_t offset = base + static_cast<size_t>(it) *
461461
static_cast<size_t>(sgSize);
462462
auto in1_multi_ptr = sycl::address_space_cast<
@@ -480,20 +480,20 @@ struct BinaryContigFunctor
480480
}
481481
}
482482
else {
483-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
484-
k += sgSize)
485-
{
483+
const std::size_t lane_id = sg.get_local_id()[0];
484+
for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
486485
out[k] = op(in1[k], in2[k]);
487486
}
488487
}
489488
}
490489
else {
491-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
492-
size_t base = ndit.get_global_linear_id();
490+
const size_t sgSize = ndit.get_sub_group().get_local_range()[0];
491+
const size_t gid = ndit.get_global_linear_id();
492+
const size_t elems_per_sg = sgSize * elems_per_wi;
493493

494-
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
494+
const size_t base = (gid / sgSize) * elems_per_sg + (gid % sgSize);
495495
for (size_t offset = base;
496-
offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
496+
offset < std::min(nelems_, base + elems_per_sg);
497497
offset += sgSize)
498498
{
499499
out[offset] = op(in1[offset], in2[offset]);

0 commit comments

Comments (0)