Skip to content

Commit c855fd1

Browse files
authored
[SYCL] Fix sub-group mask for smaller SG sizes (#4916)
Fix access to the sub-group mask when the sub-group size is less than 32. Make sure that false is returned for positions at or beyond the sub-group size. Update the test to cover this case.
1 parent 2ebde5f commit c855fd1

File tree

2 files changed

+136
-115
lines changed

2 files changed

+136
-115
lines changed

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct sub_group_mask {
5050
}
5151

5252
reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
53-
RefBit = 1 << pos % word_size;
53+
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
5454
}
5555

5656
private:
@@ -61,16 +61,17 @@ struct sub_group_mask {
6161
};
6262

6363
bool operator[](id<1> id) const {
64-
return Bits & (1 << (id.get(0) % word_size));
64+
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
6565
}
66+
6667
reference operator[](id<1> id) { return {*this, id.get(0)}; }
6768
bool test(id<1> id) const { return operator[](id); }
68-
bool all() const { return !~Bits; }
69-
bool any() const { return Bits; }
70-
bool none() const { return !Bits; }
69+
bool all() const { return count() == bits_num; }
70+
bool any() const { return count() != 0; }
71+
bool none() const { return count() == 0; }
7172
uint32_t count() const {
7273
unsigned int count = 0;
73-
auto word = Bits;
74+
auto word = (Bits & valuable_bits(bits_num));
7475
while (word) {
7576
word &= (word - 1);
7677
count++;
@@ -99,9 +100,9 @@ struct sub_group_mask {
99100
insert_data <<= pos.get(0);
100101
uint32_t mask = 0;
101102
if (pos.get(0) + insert_size < size())
102-
mask |= (0xffffffff << (pos.get(0) + insert_size));
103+
mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
103104
if (pos.get(0) < size() && pos.get(0))
104-
mask |= (0xffffffff >> (size() - pos.get(0)));
105+
mask |= (valuable_bits(max_bits) >> (max_bits - pos.get(0)));
105106
Bits &= mask;
106107
Bits += insert_data;
107108
}
@@ -125,14 +126,15 @@ struct sub_group_mask {
125126
template <typename Type,
126127
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
127128
void extract_bits(Type &bits, id<1> pos = 0) const {
128-
uint32_t Res = Bits;
129+
auto Res = Bits;
130+
Res &= valuable_bits(bits_num);
129131
if (pos.get(0) < size()) {
130132
if (pos.get(0) > 0) {
131133
Res >>= pos.get(0);
132134
}
133135

134-
if (sizeof(Type) * CHAR_BIT < size()) {
135-
Res &= (0xffffffff >> (size() - (sizeof(Type) * CHAR_BIT)));
136+
if (sizeof(Type) * CHAR_BIT < max_bits) {
137+
Res &= valuable_bits(sizeof(Type) * CHAR_BIT);
136138
}
137139
bits = (Type)Res;
138140
} else {
@@ -154,13 +156,13 @@ struct sub_group_mask {
154156
}
155157
}
156158

157-
void set() { Bits = uint32_t{0xffffffff}; }
159+
void set() { Bits = valuable_bits(bits_num); }
158160
void set(id<1> id, bool value = true) { operator[](id) = value; }
159161
void reset() { Bits = uint32_t{0}; }
160162
void reset(id<1> id) { operator[](id) = 0; }
161163
void reset_low() { reset(find_low()); }
162164
void reset_high() { reset(find_high()); }
163-
void flip() { Bits = ~Bits; }
165+
void flip() { Bits = (~Bits & valuable_bits(bits_num)); }
164166
void flip(id<1> id) { operator[](id).flip(); }
165167

166168
bool operator==(const sub_group_mask &rhs) const { return Bits == rhs.Bits; }
@@ -177,11 +179,13 @@ struct sub_group_mask {
177179

178180
sub_group_mask &operator^=(const sub_group_mask &rhs) {
179181
Bits ^= rhs.Bits;
182+
Bits &= valuable_bits(bits_num);
180183
return *this;
181184
}
182185

183186
sub_group_mask &operator<<=(size_t pos) {
184187
Bits <<= pos;
188+
Bits &= valuable_bits(bits_num);
185189
return *this;
186190
}
187191

@@ -239,6 +243,9 @@ struct sub_group_mask {
239243
sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
240244
assert(bits_num <= max_bits);
241245
}
246+
inline uint32_t valuable_bits(size_t bn) const {
247+
return static_cast<uint32_t>((1ULL << bn) - 1ULL);
248+
}
242249
uint32_t Bits;
243250
// Number of valuable bits
244251
size_t bits_num;
Lines changed: 116 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -g -O0 -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
22
// RUN: %t.out
33

44
//==-------- sub_group_mask.cpp - SYCL sub-group mask test -----------------==//
@@ -13,110 +13,124 @@
1313
#include <iostream>
1414

1515
int main() {
16-
auto g = sycl::detail::Builder::createSubGroupMask<
17-
sycl::ext::oneapi::sub_group_mask>(0, 32);
18-
assert(g.none() && !g.any() && !g.all());
19-
assert(g[10] == false); // reference::operator[](id) const;
20-
g[10] = true; // reference::operator=(bool);
21-
assert(g[10] == true);
22-
g[11] = g[10]; // reference::operator=(reference) reference::operator[](id);
23-
assert(g[10].flip() == false); // reference::flip()
24-
assert(~g[10] == true); // refernce::operator~()
25-
assert(g[10] == false);
26-
assert(g[11] == true);
27-
assert(g.test(10) == false && g.test(11) == true);
28-
g.set(30, 1);
29-
g.set(11, 0);
30-
g.set(23, 1);
31-
assert(!g.none() && g.any() && !g.all());
16+
for (size_t sgsize = 32; sgsize > 4; sgsize /= 2) {
17+
std::cout << "Running test for sub-group size = " << sgsize << std::endl;
18+
auto g = sycl::detail::Builder::createSubGroupMask<
19+
sycl::ext::oneapi::sub_group_mask>(0, sgsize);
20+
assert(g.none() && !g.any() && !g.all());
21+
assert(g[5] == false); // reference::operator[](id) const;
22+
g[5] = true; // reference::operator=(bool);
23+
assert(g[5] == true);
24+
g[6] = g[5]; // reference::operator=(reference) reference::operator[](id);
25+
assert(g[5].flip() == false); // reference::flip()
26+
assert(~g[5 % sgsize] == true); // refernce::operator~()
27+
assert(g[5 % sgsize] == false);
28+
assert(g[6 % sgsize] == true);
29+
assert(g.test(5 % sgsize) == false && g.test(6 % sgsize) == true);
30+
g.set(3 % sgsize, 1);
31+
g.set(6 % sgsize, 0);
32+
g.set(2 % sgsize, 1);
33+
assert(!g.none() && g.any() && !g.all());
3234

33-
assert(g.count() == 2);
34-
assert(g.find_low() == 23);
35-
assert(g.find_high() == 30);
36-
assert(g.size() == 32);
35+
assert(g.count() == 2);
36+
assert(g.find_low() == 2 % sgsize);
37+
assert(g.find_high() == 3 % sgsize);
38+
assert(g.size() == sgsize);
3739

38-
g.reset();
39-
assert(g.none() && !g.any() && !g.all());
40-
assert(g.find_low() == g.size() && g.find_high() == g.size());
41-
g.set();
42-
assert(!g.none() && g.any() && g.all());
43-
assert(g.find_low() == 0 && g.find_high() == 31);
44-
g.flip();
45-
assert(g.none() && !g.any() && !g.all());
40+
g.reset();
41+
assert(g.none() && !g.any() && !g.all());
42+
assert(g.find_low() == g.size() && g.find_high() == g.size());
43+
g.set();
44+
assert(!g.none() && g.any() && g.all());
45+
assert(g.find_low() == 0 && g.find_high() == 31 % sgsize);
46+
g.flip();
47+
assert(g.none() && !g.any() && !g.all());
4648

47-
g.flip(13);
48-
g.flip(23);
49-
g.flip(29);
50-
auto b = g;
51-
assert(b == g && !(b != g));
52-
g.flip(31);
53-
assert(g.find_high() == 31);
54-
assert(b.find_high() == 29);
55-
assert(b != g && !(b == g));
56-
b.flip(31);
57-
assert(b == g && !(b != g));
58-
b = g >> 1;
59-
assert(b[12] && b[22] && b[28] && b[30]);
60-
b <<= 1;
61-
assert(b == g);
62-
g ^= ~b;
63-
assert(!g.none() && g.any() && g.all());
64-
assert((g | ~g).all());
65-
assert((g & ~g).none());
66-
assert((g ^ ~g).all());
67-
b.reset_low();
68-
b.reset_high();
69-
assert(!b[13] && b[23] && b[29] && !b[31]);
70-
b.insert_bits(0x01020408);
71-
assert(b[24] && b[17] && b[10] && b[3]);
72-
b <<= 13;
73-
assert(!b[24] && !b[17] && !b[10] && !b[3] && b[30] && b[23] && b[16]);
74-
b.insert_bits((char)0b01010101, 18);
75-
assert(b[18] && b[20] && b[22] && b[24] && b[30] && !b[23] && b[16]);
76-
b[3] = true;
77-
b.insert_bits(sycl::marray<char, 8>{1, 2, 4, 8, 16, 32, 64, 128}, 5);
78-
assert(!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[3] &&
79-
b[5] && b[14] && b[23]);
80-
char r, rbc;
81-
const auto b_const{b};
82-
b.extract_bits(r);
83-
b_const.extract_bits(rbc);
84-
assert(r == 0b00101000);
85-
assert(rbc == 0b00101000);
86-
long r2 = -1, r2bc = -1;
87-
b.extract_bits(r2, 16);
88-
b_const.extract_bits(r2bc, 16);
89-
assert(r2 == 128);
90-
assert(r2bc == 128);
49+
g.flip(2);
50+
g.flip(3);
51+
g.flip(7);
52+
auto b = g;
53+
assert(b == g && !(b != g));
54+
g.flip(7);
55+
assert(g.find_high() == 3 % sgsize);
56+
assert(b.find_high() == 7 % sgsize);
57+
assert(b != g && !(b == g));
58+
g.flip(7);
59+
assert(b == g && !(b != g));
60+
b = g >> 1;
61+
assert(b[1] && b[2] && b[6]);
62+
b <<= 1;
63+
assert(b == g);
64+
g ^= ~b;
65+
assert(!g.none() && g.any() && g.all());
66+
assert((g | ~g).all());
67+
assert((g & ~g).none());
68+
assert((g ^ ~g).all());
69+
b.reset_low();
70+
b.reset_high();
71+
assert(!b[2] && b[3] && !b[7]);
72+
b.insert_bits(0x01020408);
73+
assert(((b[24] && b[17]) || sgsize < 32) && (b[10] || sgsize < 16) && b[3]);
74+
b <<= 10;
75+
assert(((!b[24] && !b[17] && b[27] && b[20]) || sgsize < 32) &&
76+
((!b[10] && b[13]) || sgsize < 16) && !b[3]);
77+
b.insert_bits((char)0b01010101, 6);
78+
assert(b[6] && ((b[8] && b[10] && b[12] && !b[13]) || sgsize < 16));
79+
b[3] = true;
80+
b.insert_bits(sycl::marray<char, 8>{1, 2, 4, 8, 16, 32, 64, 128}, 5);
81+
assert(
82+
((!b[18] && !b[20] && !b[22] && !b[24] && !b[30] && !b[16] && b[23]) ||
83+
sgsize < 32) &&
84+
b[3] && b[5] && (b[14] || sgsize < 16));
85+
b.flip(14);
86+
b.flip(23);
87+
char r, rbc;
88+
const auto b_const{b};
89+
b.extract_bits(r);
90+
b_const.extract_bits(rbc);
91+
assert(r == 0b00101000);
92+
assert(rbc == 0b00101000);
93+
long r2 = -1, r2bc = -1;
94+
b.extract_bits(r2, 3);
95+
b_const.extract_bits(r2bc, 3);
96+
assert(r2 == 5);
97+
assert(r2bc == 5);
9198

92-
b[31] = true;
93-
const auto b_const2{b};
94-
sycl::marray<char, 6> r3{-1}, r3bc{-1};
95-
b.extract_bits(r3, 14);
96-
b_const2.extract_bits(r3bc, 14);
97-
assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]);
98-
assert(r3bc[0] == 1 && r3bc[1] == 2 && r3bc[2] == 2 && !r3bc[3] && !r3bc[4] &&
99-
!r3bc[5]);
100-
int ibits = 0b1010101010101010101010101010101;
101-
b.insert_bits(ibits);
102-
for (size_t i = 0; i < 32; i++) {
103-
assert(b[i] != (bool)(i % 2));
99+
b.insert_bits((uint32_t)0x08040201);
100+
const auto b_const2{b};
101+
sycl::marray<char, 6> r3{-1}, r3bc{-1};
102+
b.extract_bits(r3);
103+
b_const2.extract_bits(r3bc);
104+
assert(r3[0] == 1 && r3[1] == (sgsize > 8 ? 2 : 0) &&
105+
r3[2] == (sgsize > 16 ? 4 : 0) && r3[3] == (sgsize > 16 ? 8 : 0) &&
106+
!r3[4] && !r3[5]);
107+
assert(r3bc[0] == 1 && r3bc[1] == (sgsize > 8 ? 2 : 0) &&
108+
r3bc[2] == (sgsize > 16 ? 4 : 0) &&
109+
r3bc[3] == (sgsize > 16 ? 8 : 0) && !r3bc[4] && !r3bc[5]);
110+
int ibits = 0b1010101010101010101010101010101;
111+
b.insert_bits(ibits);
112+
for (size_t i = 0; i < sgsize; i++) {
113+
assert(b[i] != (bool)(i % 2));
114+
}
115+
short sbits = 0b0111011101110111;
116+
b.insert_bits(sbits, 7);
117+
b.extract_bits(ibits);
118+
assert(ibits ==
119+
(0b1010101001110111011101111010101 & ((1ULL << sgsize) - 1ULL)));
120+
sbits = 0b1100001111000011;
121+
b.insert_bits(sbits, 23);
122+
b.extract_bits(ibits);
123+
if (sgsize >= 32) {
124+
int64_t lbits = -1;
125+
b.extract_bits(lbits, 33);
126+
assert(lbits == 0);
127+
lbits = -1;
128+
b.extract_bits(lbits, 5);
129+
assert(lbits ==
130+
(0b111000011011101110111011110 & ((1ULL << sgsize) - 1ULL)));
131+
lbits = -1;
132+
b.insert_bits(lbits);
133+
assert(b.all());
134+
}
104135
}
105-
short sbits = 0b0111011101110111;
106-
b.insert_bits(sbits, 7);
107-
b.extract_bits(ibits);
108-
assert(ibits == 0b1010101001110111011101111010101);
109-
sbits = 0b1100001111000011;
110-
b.insert_bits(sbits, 23);
111-
b.extract_bits(ibits);
112-
assert(ibits == 0b11100001101110111011101111010101);
113-
int64_t lbits = -1;
114-
b.extract_bits(lbits, 33);
115-
assert(lbits == 0);
116-
lbits = -1;
117-
b.extract_bits(lbits, 5);
118-
assert(lbits == 0b111000011011101110111011110);
119-
lbits = -1;
120-
b.insert_bits(lbits);
121-
assert(b.all());
122136
}

0 commit comments

Comments
 (0)