@@ -72,22 +72,19 @@ struct UnaryContigFunctor
72
72
{
73
73
UnaryOperatorT op{};
74
74
/* Each work-item processes vec_sz elements, contiguous in memory */
75
- /* NOTE: vec_sz must divide sg.max_local_range()[0] */
75
+ /* NOTE: work-group size must be divisible by sub-group size */
76
76
77
77
if constexpr (enable_sg_loadstore && UnaryOperatorT::is_constant::value)
78
78
{
79
79
// the operator's result is a compile-time constant (UnaryOperatorT::constant_value)
80
80
constexpr resT const_val = UnaryOperatorT::constant_value;
81
81
82
82
auto sg = ndit.get_sub_group ();
83
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
84
- std::uint8_t max_sgSize = sg.get_max_local_range ()[0 ];
83
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
85
84
size_t base = n_vecs * vec_sz *
86
85
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
87
86
sg.get_group_id ()[0 ] * sgSize);
88
- if (base + n_vecs * vec_sz * sgSize < nelems_ &&
89
- max_sgSize == sgSize)
90
- {
87
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
91
88
sycl::vec<resT, vec_sz> res_vec (const_val);
92
89
#pragma unroll
93
90
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
@@ -113,14 +110,11 @@ struct UnaryContigFunctor
113
110
UnaryOperatorT::supports_vec::value)
114
111
{
115
112
auto sg = ndit.get_sub_group ();
116
- std::uint16_t sgSize = sg.get_local_range ()[0 ];
117
- std::uint16_t max_sgSize = sg.get_max_local_range ()[0 ];
113
+ std::uint16_t sgSize = sg.get_max_local_range ()[0 ];
118
114
size_t base = n_vecs * vec_sz *
119
115
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
120
- sg.get_group_id ()[0 ] * max_sgSize);
121
- if (base + n_vecs * vec_sz * sgSize < nelems_ &&
122
- sgSize == max_sgSize)
123
- {
116
+ sg.get_group_id ()[0 ] * sgSize);
117
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
124
118
sycl::vec<argT, vec_sz> x;
125
119
126
120
#pragma unroll
@@ -155,15 +149,12 @@ struct UnaryContigFunctor
155
149
// default: use scalar-value function
156
150
157
151
auto sg = ndit.get_sub_group ();
158
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
159
- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
152
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
160
153
size_t base = n_vecs * vec_sz *
161
154
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
162
- sg.get_group_id ()[0 ] * maxsgSize );
155
+ sg.get_group_id ()[0 ] * sgSize );
163
156
164
- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
165
- (maxsgSize == sgSize))
166
- {
157
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
167
158
sycl::vec<argT, vec_sz> arg_vec;
168
159
169
160
#pragma unroll
@@ -199,15 +190,12 @@ struct UnaryContigFunctor
199
190
// default: use scalar-value function
200
191
201
192
auto sg = ndit.get_sub_group ();
202
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
203
- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
193
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
204
194
size_t base = n_vecs * vec_sz *
205
195
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
206
- sg.get_group_id ()[0 ] * maxsgSize );
196
+ sg.get_group_id ()[0 ] * sgSize );
207
197
208
- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
209
- (maxsgSize == sgSize))
210
- {
198
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
211
199
sycl::vec<argT, vec_sz> arg_vec;
212
200
sycl::vec<resT, vec_sz> res_vec;
213
201
@@ -406,22 +394,20 @@ struct BinaryContigFunctor
406
394
{
407
395
BinaryOperatorT op{};
408
396
/* Each work-item processes vec_sz elements, contiguous in memory */
397
+ /* NOTE: work-group size must be divisible by sub-group size */
409
398
410
399
if constexpr (enable_sg_loadstore &&
411
400
BinaryOperatorT::supports_sg_loadstore::value &&
412
401
BinaryOperatorT::supports_vec::value)
413
402
{
414
403
auto sg = ndit.get_sub_group ();
415
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
416
- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
404
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
417
405
418
406
size_t base = n_vecs * vec_sz *
419
407
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
420
408
sg.get_group_id ()[0 ] * sgSize);
421
409
422
- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
423
- (sgSize == maxsgSize))
424
- {
410
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
425
411
sycl::vec<argT1, vec_sz> arg1_vec;
426
412
sycl::vec<argT2, vec_sz> arg2_vec;
427
413
sycl::vec<resT, vec_sz> res_vec;
@@ -458,16 +444,13 @@ struct BinaryContigFunctor
458
444
BinaryOperatorT::supports_sg_loadstore::value)
459
445
{
460
446
auto sg = ndit.get_sub_group ();
461
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
462
- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
447
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
463
448
464
449
size_t base = n_vecs * vec_sz *
465
450
(ndit.get_group (0 ) * ndit.get_local_range (0 ) +
466
451
sg.get_group_id ()[0 ] * sgSize);
467
452
468
- if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
469
- (sgSize == maxsgSize))
470
- {
453
+ if (base + n_vecs * vec_sz * sgSize < nelems_) {
471
454
sycl::vec<argT1, vec_sz> arg1_vec;
472
455
sycl::vec<argT2, vec_sz> arg2_vec;
473
456
sycl::vec<resT, vec_sz> res_vec;
@@ -582,13 +565,15 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
582
565
583
566
void operator ()(sycl::nd_item<1 > ndit) const
584
567
{
568
+ /* NOTE: work-group size must be divisible by sub-group size */
569
+
585
570
BinaryOperatorT op{};
586
571
static_assert (BinaryOperatorT::supports_sg_loadstore::value);
587
572
588
573
auto sg = ndit.get_sub_group ();
589
574
size_t gid = ndit.get_global_linear_id ();
590
575
591
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
576
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
592
577
size_t base = gid - sg.get_local_id ()[0 ];
593
578
594
579
if (base + sgSize < n_elems) {
@@ -647,13 +632,14 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
647
632
648
633
void operator ()(sycl::nd_item<1 > ndit) const
649
634
{
635
+ /* NOTE: work-group size must be divisible by sub-group size */
650
636
BinaryOperatorT op{};
651
637
static_assert (BinaryOperatorT::supports_sg_loadstore::value);
652
638
653
639
auto sg = ndit.get_sub_group ();
654
640
size_t gid = ndit.get_global_linear_id ();
655
641
656
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
642
+ std::uint8_t sgSize = sg.get_max_local_range ()[0 ];
657
643
size_t base = gid - sg.get_local_id ()[0 ];
658
644
659
645
if (base + sgSize < n_elems) {
0 commit comments