Skip to content

Commit f88a19e

Browse files
authored
[SYCL] Fix overwriting insert to sub_group_mask (intel#4656)
Make sure that old value is cleared when inserted bits ovewrite whole mask.
1 parent 7a93d99 commit f88a19e

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ struct sub_group_mask {
100100
uint32_t mask = 0;
101101
if (pos.get(0) + insert_size < size())
102102
mask |= (0xffffffff << (pos.get(0) + insert_size));
103-
if (pos.get(0) < size())
103+
if (pos.get(0) < size() && pos.get(0))
104104
mask |= (0xffffffff >> (size() - pos.get(0)));
105105
Bits &= mask;
106106
Bits += insert_data;

sycl/test/extensions/sub_group_mask.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,26 @@ int main() {
8787
sycl::marray<char, 6> r3{-1};
8888
b.extract_bits(r3, 14);
8989
assert(r3[0] == 1 && r3[1] == 2 && r3[2] == 2 && !r3[3] && !r3[4] && !r3[5]);
90+
int ibits = 0b1010101010101010101010101010101;
91+
b.insert_bits(ibits);
92+
for (size_t i = 0; i < 32; i++) {
93+
assert(b[i] != (bool)(i % 2));
94+
}
95+
short sbits = 0b0111011101110111;
96+
b.insert_bits(sbits, 7);
97+
b.extract_bits(ibits);
98+
assert(ibits == 0b1010101001110111011101111010101);
99+
sbits = 0b1100001111000011;
100+
b.insert_bits(sbits, 23);
101+
b.extract_bits(ibits);
102+
assert(ibits == 0b11100001101110111011101111010101);
103+
int64_t lbits = -1;
104+
b.extract_bits(lbits, 33);
105+
assert(lbits == 0);
106+
lbits = -1;
107+
b.extract_bits(lbits, 5);
108+
assert(lbits == 0b111000011011101110111011110);
109+
lbits = -1;
110+
b.insert_bits(lbits);
111+
assert(b.all());
90112
}

0 commit comments

Comments
 (0)