Skip to content

Commit 20a1966

Browse files
Do not use sg.get_local_range
Use sg.get_max_local_range instead. The `sg.get_local_range` must perform lots of checks to determine if this is the last trailing sub-group in the work-group and its actual size may be smaller. We set the local work-group size to be 128, which is a multiple of any sub-group size, and hence get_local_range() always equals to get_max_local_raneg().
1 parent 286afae commit 20a1966

File tree

2 files changed

+29
-47
lines changed

2 files changed

+29
-47
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,19 @@ struct UnaryContigFunctor
7272
{
7373
UnaryOperatorT op{};
7474
/* Each work-item processes vec_sz elements, contiguous in memory */
75-
/* NOTE: vec_sz must divide sg.max_local_range()[0] */
75+
/* NOTE: work-group size must be divisible by sub-group size */
7676

7777
if constexpr (enable_sg_loadstore && UnaryOperatorT::is_constant::value)
7878
{
7979
// value of operator is known to be a known constant
8080
constexpr resT const_val = UnaryOperatorT::constant_value;
8181

8282
auto sg = ndit.get_sub_group();
83-
std::uint8_t sgSize = sg.get_local_range()[0];
84-
std::uint8_t max_sgSize = sg.get_max_local_range()[0];
83+
std::uint8_t sgSize = sg.get_max_local_range()[0];
8584
size_t base = n_vecs * vec_sz *
8685
(ndit.get_group(0) * ndit.get_local_range(0) +
8786
sg.get_group_id()[0] * sgSize);
88-
if (base + n_vecs * vec_sz * sgSize < nelems_ &&
89-
max_sgSize == sgSize)
90-
{
87+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
9188
sycl::vec<resT, vec_sz> res_vec(const_val);
9289
#pragma unroll
9390
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
@@ -113,14 +110,11 @@ struct UnaryContigFunctor
113110
UnaryOperatorT::supports_vec::value)
114111
{
115112
auto sg = ndit.get_sub_group();
116-
std::uint16_t sgSize = sg.get_local_range()[0];
117-
std::uint16_t max_sgSize = sg.get_max_local_range()[0];
113+
std::uint16_t sgSize = sg.get_max_local_range()[0];
118114
size_t base = n_vecs * vec_sz *
119115
(ndit.get_group(0) * ndit.get_local_range(0) +
120-
sg.get_group_id()[0] * max_sgSize);
121-
if (base + n_vecs * vec_sz * sgSize < nelems_ &&
122-
sgSize == max_sgSize)
123-
{
116+
sg.get_group_id()[0] * sgSize);
117+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
124118
sycl::vec<argT, vec_sz> x;
125119

126120
#pragma unroll
@@ -155,15 +149,12 @@ struct UnaryContigFunctor
155149
// default: use scalar-value function
156150

157151
auto sg = ndit.get_sub_group();
158-
std::uint8_t sgSize = sg.get_local_range()[0];
159-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
152+
std::uint8_t sgSize = sg.get_max_local_range()[0];
160153
size_t base = n_vecs * vec_sz *
161154
(ndit.get_group(0) * ndit.get_local_range(0) +
162-
sg.get_group_id()[0] * maxsgSize);
155+
sg.get_group_id()[0] * sgSize);
163156

164-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
165-
(maxsgSize == sgSize))
166-
{
157+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
167158
sycl::vec<argT, vec_sz> arg_vec;
168159

169160
#pragma unroll
@@ -199,15 +190,12 @@ struct UnaryContigFunctor
199190
// default: use scalar-value function
200191

201192
auto sg = ndit.get_sub_group();
202-
std::uint8_t sgSize = sg.get_local_range()[0];
203-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
193+
std::uint8_t sgSize = sg.get_max_local_range()[0];
204194
size_t base = n_vecs * vec_sz *
205195
(ndit.get_group(0) * ndit.get_local_range(0) +
206-
sg.get_group_id()[0] * maxsgSize);
196+
sg.get_group_id()[0] * sgSize);
207197

208-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
209-
(maxsgSize == sgSize))
210-
{
198+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
211199
sycl::vec<argT, vec_sz> arg_vec;
212200
sycl::vec<resT, vec_sz> res_vec;
213201

@@ -406,22 +394,20 @@ struct BinaryContigFunctor
406394
{
407395
BinaryOperatorT op{};
408396
/* Each work-item processes vec_sz elements, contiguous in memory */
397+
/* NOTE: work-group size must be divisible by sub-group size */
409398

410399
if constexpr (enable_sg_loadstore &&
411400
BinaryOperatorT::supports_sg_loadstore::value &&
412401
BinaryOperatorT::supports_vec::value)
413402
{
414403
auto sg = ndit.get_sub_group();
415-
std::uint8_t sgSize = sg.get_local_range()[0];
416-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
404+
std::uint8_t sgSize = sg.get_max_local_range()[0];
417405

418406
size_t base = n_vecs * vec_sz *
419407
(ndit.get_group(0) * ndit.get_local_range(0) +
420408
sg.get_group_id()[0] * sgSize);
421409

422-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
423-
(sgSize == maxsgSize))
424-
{
410+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
425411
sycl::vec<argT1, vec_sz> arg1_vec;
426412
sycl::vec<argT2, vec_sz> arg2_vec;
427413
sycl::vec<resT, vec_sz> res_vec;
@@ -458,16 +444,13 @@ struct BinaryContigFunctor
458444
BinaryOperatorT::supports_sg_loadstore::value)
459445
{
460446
auto sg = ndit.get_sub_group();
461-
std::uint8_t sgSize = sg.get_local_range()[0];
462-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
447+
std::uint8_t sgSize = sg.get_max_local_range()[0];
463448

464449
size_t base = n_vecs * vec_sz *
465450
(ndit.get_group(0) * ndit.get_local_range(0) +
466451
sg.get_group_id()[0] * sgSize);
467452

468-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
469-
(sgSize == maxsgSize))
470-
{
453+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
471454
sycl::vec<argT1, vec_sz> arg1_vec;
472455
sycl::vec<argT2, vec_sz> arg2_vec;
473456
sycl::vec<resT, vec_sz> res_vec;
@@ -582,13 +565,15 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
582565

583566
void operator()(sycl::nd_item<1> ndit) const
584567
{
568+
/* NOTE: work-group size must be divisible by sub-group size */
569+
585570
BinaryOperatorT op{};
586571
static_assert(BinaryOperatorT::supports_sg_loadstore::value);
587572

588573
auto sg = ndit.get_sub_group();
589574
size_t gid = ndit.get_global_linear_id();
590575

591-
std::uint8_t sgSize = sg.get_local_range()[0];
576+
std::uint8_t sgSize = sg.get_max_local_range()[0];
592577
size_t base = gid - sg.get_local_id()[0];
593578

594579
if (base + sgSize < n_elems) {
@@ -647,13 +632,14 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
647632

648633
void operator()(sycl::nd_item<1> ndit) const
649634
{
635+
/* NOTE: work-group size must be divisible by sub-group size */
650636
BinaryOperatorT op{};
651637
static_assert(BinaryOperatorT::supports_sg_loadstore::value);
652638

653639
auto sg = ndit.get_sub_group();
654640
size_t gid = ndit.get_global_linear_id();
655641

656-
std::uint8_t sgSize = sg.get_local_range()[0];
642+
std::uint8_t sgSize = sg.get_max_local_range()[0];
657643
size_t base = gid - sg.get_local_id()[0];
658644

659645
if (base + sgSize < n_elems) {

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,20 @@ struct BinaryInplaceContigFunctor
7373
{
7474
BinaryInplaceOperatorT op{};
7575
/* Each work-item processes vec_sz elements, contiguous in memory */
76+
/* NB: Workgroup size must be divisible by sub-group size */
7677

7778
if constexpr (enable_sg_loadstore &&
7879
BinaryInplaceOperatorT::supports_sg_loadstore::value &&
7980
BinaryInplaceOperatorT::supports_vec::value)
8081
{
8182
auto sg = ndit.get_sub_group();
82-
std::uint8_t sgSize = sg.get_local_range()[0];
83-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
83+
std::uint8_t sgSize = sg.get_max_local_range()[0];
8484

8585
size_t base = n_vecs * vec_sz *
8686
(ndit.get_group(0) * ndit.get_local_range(0) +
8787
sg.get_group_id()[0] * sgSize);
8888

89-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
90-
(sgSize == maxsgSize))
91-
{
89+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
9290

9391
sycl::vec<argT, vec_sz> arg_vec;
9492
sycl::vec<resT, vec_sz> res_vec;
@@ -121,16 +119,13 @@ struct BinaryInplaceContigFunctor
121119
BinaryInplaceOperatorT::supports_sg_loadstore::value)
122120
{
123121
auto sg = ndit.get_sub_group();
124-
std::uint8_t sgSize = sg.get_local_range()[0];
125-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
122+
std::uint8_t sgSize = sg.get_max_local_range()[0];
126123

127124
size_t base = n_vecs * vec_sz *
128125
(ndit.get_group(0) * ndit.get_local_range(0) +
129126
sg.get_group_id()[0] * sgSize);
130127

131-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
132-
(sgSize == maxsgSize))
133-
{
128+
if (base + n_vecs * vec_sz * sgSize < nelems_) {
134129
sycl::vec<argT, vec_sz> arg_vec;
135130
sycl::vec<resT, vec_sz> res_vec;
136131

@@ -228,13 +223,14 @@ struct BinaryInplaceRowMatrixBroadcastingFunctor
228223

229224
void operator()(sycl::nd_item<1> ndit) const
230225
{
226+
/* Workgroup size is expected to be a multiple of sub-group size */
231227
BinaryOperatorT op{};
232228
static_assert(BinaryOperatorT::supports_sg_loadstore::value);
233229

234230
auto sg = ndit.get_sub_group();
235231
size_t gid = ndit.get_global_linear_id();
236232

237-
std::uint8_t sgSize = sg.get_local_range()[0];
233+
std::uint8_t sgSize = sg.get_max_local_range()[0];
238234
size_t base = gid - sg.get_local_id()[0];
239235

240236
if (base + sgSize < n_elems) {

0 commit comments

Comments
 (0)