@@ -70,6 +70,7 @@ struct UnaryContigFunctor

    void operator()(sycl::nd_item<1> ndit) const
    {
+        constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
        UnaryOperatorT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */
@@ -80,14 +81,15 @@ struct UnaryContigFunctor
            constexpr resT const_val = UnaryOperatorT::constant_value;

            auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
+            std::uint32_t sgSize = sg.get_max_local_range()[0];
+
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);
-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<resT, vec_sz> res_vec(const_val);
#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto out_multi_ptr = sycl::address_space_cast<
@@ -98,9 +100,8 @@ struct UnaryContigFunctor
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out[k] = const_val;
                }
            }
@@ -110,15 +111,16 @@ struct UnaryContigFunctor
                           UnaryOperatorT::supports_vec::value)
        {
            auto sg = ndit.get_sub_group();
-            std::uint16_t sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
+            std::uint32_t sgSize = sg.get_max_local_range()[0];
+
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);
-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<argT, vec_sz> x;

#pragma unroll
-                for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto in_multi_ptr = sycl::address_space_cast<
@@ -134,9 +136,8 @@ struct UnaryContigFunctor
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    // scalar call
                    out[k] = op(in[k]);
                }
@@ -149,16 +150,16 @@ struct UnaryContigFunctor
            // default: use scalar-value function

            auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
+            std::uint32_t sgSize = sg.get_max_local_range()[0];
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);

-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<argT, vec_sz> arg_vec;

#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto in_multi_ptr = sycl::address_space_cast<
@@ -170,16 +171,15 @@ struct UnaryContigFunctor

                    arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
-                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
+                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        arg_vec[k] = op(arg_vec[k]);
                    }
                    sg.store<vec_sz>(out_multi_ptr, arg_vec);
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out[k] = op(in[k]);
                }
            }
@@ -190,17 +190,17 @@ struct UnaryContigFunctor
            // default: use scalar-value function

            auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
+            std::uint32_t sgSize = sg.get_max_local_range()[0];
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);

-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<argT, vec_sz> arg_vec;
                sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto in_multi_ptr = sycl::address_space_cast<
@@ -212,27 +212,27 @@ struct UnaryContigFunctor

                    arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
-                    for (std::uint8_t k = 0; k < vec_sz; ++k) {
+                    for (std::uint32_t k = 0; k < vec_sz; ++k) {
                        res_vec[k] = op(arg_vec[k]);
                    }
                    sg.store<vec_sz>(out_multi_ptr, res_vec);
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out[k] = op(in[k]);
                }
            }
        }
        else {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            size_t sgSize = ndit.get_sub_group().get_local_range()[0];
            size_t base = ndit.get_global_linear_id();
+            const size_t elems_per_sg = sgSize * elems_per_wi;

-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            base = (base / sgSize) * elems_per_sg + (base % sgSize);
            for (size_t offset = base;
-                 offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
+                 offset < std::min(nelems_, base + elems_per_sg);
                 offset += sgSize)
            {
                out[offset] = op(in[offset]);
@@ -392,6 +392,7 @@ struct BinaryContigFunctor

    void operator()(sycl::nd_item<1> ndit) const
    {
+        constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;
        BinaryOperatorT op{};
        /* Each work-item processes vec_sz elements, contiguous in memory */
        /* NOTE: work-group size must be divisible by sub-group size */
@@ -401,19 +402,19 @@ struct BinaryContigFunctor
                           BinaryOperatorT::supports_vec::value)
        {
            auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_max_local_range()[0];
+            std::uint16_t sgSize = sg.get_max_local_range()[0];

-            size_t base = n_vecs * vec_sz *
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);

-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<argT1, vec_sz> arg1_vec;
                sycl::vec<argT2, vec_sz> arg2_vec;
                sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto in1_multi_ptr = sycl::address_space_cast<
@@ -433,9 +434,8 @@ struct BinaryContigFunctor
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const std::size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out[k] = op(in1[k], in2[k]);
                }
            }
@@ -446,17 +446,17 @@ struct BinaryContigFunctor
            auto sg = ndit.get_sub_group();
            std::uint8_t sgSize = sg.get_max_local_range()[0];

-            size_t base = n_vecs * vec_sz *
+            size_t base = static_cast<size_t>(elems_per_wi) *
                          (ndit.get_group(0) * ndit.get_local_range(0) +
                           sg.get_group_id()[0] * sgSize);

-            if (base + n_vecs * vec_sz * sgSize < nelems_) {
+            if (base + static_cast<size_t>(elems_per_wi * sgSize) < nelems_) {
                sycl::vec<argT1, vec_sz> arg1_vec;
                sycl::vec<argT2, vec_sz> arg2_vec;
                sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
+                for (std::uint32_t it = 0; it < elems_per_wi; it += vec_sz) {
                    size_t offset = base + static_cast<size_t>(it) *
                                               static_cast<size_t>(sgSize);
                    auto in1_multi_ptr = sycl::address_space_cast<
@@ -480,20 +480,20 @@ struct BinaryContigFunctor
                }
            }
            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize)
-                {
+                const std::size_t lane_id = sg.get_local_id()[0];
+                for (size_t k = base + lane_id; k < nelems_; k += sgSize) {
                    out[k] = op(in1[k], in2[k]);
                }
            }
        }
        else {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
-            size_t base = ndit.get_global_linear_id();
+            const size_t sgSize = ndit.get_sub_group().get_local_range()[0];
+            const size_t gid = ndit.get_global_linear_id();
+            const size_t elems_per_sg = sgSize * elems_per_wi;

-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
+            const size_t base = (gid / sgSize) * elems_per_sg + (gid % sgSize);
            for (size_t offset = base;
-                 offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
+                 offset < std::min(nelems_, base + elems_per_sg);
                 offset += sgSize)
            {
                out[offset] = op(in1[offset], in2[offset]);
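For orientation, the sketch below reproduces, on the host, the index arithmetic used in the scalar fallback branches of the patched functors. It is not part of the patch: the values chosen for `vec_sz`, `n_vecs`, the sub-group size `sgSize`, the element count `nelems`, and the work-item id `gid` are illustrative assumptions; only the formulas `elems_per_sg = sgSize * elems_per_wi` and `base = (gid / sgSize) * elems_per_sg + (gid % sgSize)` are taken from the diff.

```cpp
// Host-side illustration of the sub-group-strided indexing from the patch.
// Assumed parameters: vec_sz = 4, n_vecs = 2, sub-group size 8, 100 elements.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    constexpr std::uint32_t vec_sz = 4;  // elements per sycl::vec load/store (assumed)
    constexpr std::uint32_t n_vecs = 2;  // vec operations per work-item (assumed)
    constexpr std::uint32_t elems_per_wi = n_vecs * vec_sz;

    const std::size_t sgSize = 8;    // assumed sub-group size
    const std::size_t nelems = 100;  // assumed total number of elements
    const std::size_t elems_per_sg = sgSize * elems_per_wi;

    // Formula from the fallback branch: each sub-group owns a contiguous block
    // of elems_per_sg elements, and a work-item starts at its lane within it.
    const std::size_t gid = 9;  // example global linear id
    const std::size_t base = (gid / sgSize) * elems_per_sg + (gid % sgSize);

    // Indices this work-item would touch, stepping by the sub-group size.
    for (std::size_t offset = base;
         offset < std::min(nelems, base + elems_per_sg); offset += sgSize)
    {
        std::cout << offset << ' ';
    }
    std::cout << '\n';
    return 0;
}
```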