Skip to content

Commit 4bc9745

Browse files
[SYCL] Fix ballot_group when the sub-group is not full size (#12737)
Not all sub-groups are necessarily the maximum sub-group size of the kernel invocation, so non-uniform groups must handle smaller sub-groups properly. However, because ballot_group builds the mask for the false-predicate group by simply negating the ballot mask, that group behaves as if the sub-group had the full 32-element size no matter how big the actual sub-group is. This commit fixes this issue. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 358843a commit 4bc9745

File tree

6 files changed

+186
-147
lines changed

6 files changed

+186
-147
lines changed

sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,13 @@ get_ballot_group(Group group, bool predicate) {
153153
if (predicate) {
154154
return ballot_group<sycl::sub_group>(mask, predicate);
155155
} else {
156-
return ballot_group<sycl::sub_group>(~mask, predicate);
156+
// To negate the mask for the false-predicate group, we also need to exclude
157+
// all parts of the mask that are not part of the group.
158+
sub_group_mask::BitsType participant_filter =
159+
(~sub_group_mask::BitsType{0}) >>
160+
(sub_group_mask::max_bits - group.get_local_linear_range());
161+
return ballot_group<sycl::sub_group>((~mask) & participant_filter,
162+
predicate);
157163
}
158164
#endif
159165
#else

sycl/include/sycl/ext/oneapi/experimental/fixed_size_group.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ template <size_t PartitionSize, typename ParentGroup> class fixed_size_group {
6464

6565
range_type get_group_range() const {
6666
#ifdef __SYCL_DEVICE_ONLY__
67-
return __spirv_SubgroupMaxSize() / PartitionSize;
67+
return __spirv_SubgroupSize() / PartitionSize;
6868
#else
6969
throw runtime_error("Non-uniform groups are not supported on host device.",
7070
PI_ERROR_INVALID_DEVICE);

sycl/test-e2e/NonUniformGroups/ballot_group.cpp

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,43 +20,51 @@ int main() {
2020
return 0;
2121
}
2222

23-
sycl::buffer<bool, 1> MatchBuf{sycl::range{32}};
24-
sycl::buffer<bool, 1> LeaderBuf{sycl::range{32}};
25-
26-
const auto NDR = sycl::nd_range<1>{32, 32};
27-
Q.submit([&](sycl::handler &CGH) {
28-
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
29-
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
30-
const auto KernelFunc =
31-
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
32-
auto WI = item.get_global_id();
33-
auto SG = item.get_sub_group();
34-
35-
// Split into odd and even work-items.
36-
bool Predicate = WI % 2 == 0;
37-
auto BallotGroup = syclex::get_ballot_group(SG, Predicate);
38-
39-
// Check function return values match Predicate.
40-
// NB: Test currently uses exactly one sub-group, but we use SG
41-
// below in case this changes in future.
42-
bool Match = true;
43-
auto GroupID = (Predicate) ? 1 : 0;
44-
auto LocalID = SG.get_local_id() / 2;
45-
Match &= (BallotGroup.get_group_id() == GroupID);
46-
Match &= (BallotGroup.get_local_id() == LocalID);
47-
Match &= (BallotGroup.get_group_range() == 2);
48-
Match &= (BallotGroup.get_local_range() == 16);
49-
MatchAcc[WI] = Match;
50-
LeaderAcc[WI] = BallotGroup.leader();
51-
};
52-
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
53-
});
54-
55-
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
56-
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
57-
for (int WI = 0; WI < 32; ++WI) {
58-
assert(MatchAcc[WI] == true);
59-
assert(LeaderAcc[WI] == (WI < 2));
23+
// Test for both the full sub-group size and a case with less work than a full
24+
// sub-group.
25+
for (size_t WGS : std::array<size_t, 2>{32, 16}) {
26+
std::cout << "Testing for work size " << WGS << std::endl;
27+
28+
sycl::buffer<bool, 1> MatchBuf{sycl::range{WGS}};
29+
sycl::buffer<bool, 1> LeaderBuf{sycl::range{WGS}};
30+
31+
const auto NDR = sycl::nd_range<1>{WGS, WGS};
32+
Q.submit([&](sycl::handler &CGH) {
33+
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
34+
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
35+
const auto KernelFunc =
36+
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
37+
auto WI = item.get_global_id();
38+
auto SG = item.get_sub_group();
39+
40+
// Split into odd and even work-items.
41+
bool Predicate = WI % 2 == 0;
42+
auto BallotGroup = syclex::get_ballot_group(SG, Predicate);
43+
44+
// Check function return values match Predicate.
45+
// NB: Test currently uses exactly one sub-group, but we use SG
46+
// below in case this changes in future.
47+
bool Match = true;
48+
auto GroupID = (Predicate) ? 1 : 0;
49+
auto LocalID = SG.get_local_id() / 2;
50+
Match &= (BallotGroup.get_group_id() == GroupID);
51+
Match &= (BallotGroup.get_local_id() == LocalID);
52+
Match &= (BallotGroup.get_group_range() == 2);
53+
Match &= (BallotGroup.get_local_range() ==
54+
SG.get_local_linear_range() / 2);
55+
MatchAcc[WI] = Match;
56+
LeaderAcc[WI] = BallotGroup.leader();
57+
};
58+
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
59+
});
60+
61+
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
62+
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
63+
for (int WI = 0; WI < WGS; ++WI) {
64+
assert(MatchAcc[WI] == true);
65+
assert(LeaderAcc[WI] == (WI < 2));
66+
}
6067
}
68+
6169
return 0;
6270
}

sycl/test-e2e/NonUniformGroups/fixed_size_group.cpp

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,47 @@ template <size_t PartitionSize> class TestKernel;
1414
template <size_t PartitionSize> void test() {
1515
sycl::queue Q;
1616

17-
sycl::buffer<bool, 1> MatchBuf{sycl::range{32}};
18-
sycl::buffer<bool, 1> LeaderBuf{sycl::range{32}};
19-
20-
const auto NDR = sycl::nd_range<1>{32, 32};
21-
Q.submit([&](sycl::handler &CGH) {
22-
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
23-
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
24-
const auto KernelFunc =
25-
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
26-
auto WI = item.get_global_id();
27-
auto SG = item.get_sub_group();
28-
29-
auto Partition = syclex::get_fixed_size_group<PartitionSize>(SG);
30-
31-
bool Match = true;
32-
Match &= (Partition.get_group_id() == (WI / PartitionSize));
33-
Match &= (Partition.get_local_id() == (WI % PartitionSize));
34-
Match &= (Partition.get_group_range() == (32 / PartitionSize));
35-
Match &= (Partition.get_local_range() == PartitionSize);
36-
MatchAcc[WI] = Match;
37-
LeaderAcc[WI] = Partition.leader();
38-
};
39-
CGH.parallel_for<TestKernel<PartitionSize>>(NDR, KernelFunc);
40-
});
41-
42-
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
43-
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
44-
for (int WI = 0; WI < 32; ++WI) {
45-
assert(MatchAcc[WI] == true);
46-
assert(LeaderAcc[WI] == ((WI % PartitionSize) == 0));
17+
// Test for both the full sub-group size and a case with less work than a full
18+
// sub-group.
19+
for (size_t WGS : std::array<size_t, 2>{32, 16}) {
20+
if (WGS < PartitionSize)
21+
continue;
22+
23+
std::cout << "Testing for work size " << WGS << " and partition size "
24+
<< PartitionSize << std::endl;
25+
26+
sycl::buffer<bool, 1> MatchBuf{sycl::range{WGS}};
27+
sycl::buffer<bool, 1> LeaderBuf{sycl::range{WGS}};
28+
29+
const auto NDR = sycl::nd_range<1>{WGS, WGS};
30+
Q.submit([&](sycl::handler &CGH) {
31+
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
32+
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
33+
const auto KernelFunc =
34+
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
35+
auto WI = item.get_global_id();
36+
auto SG = item.get_sub_group();
37+
auto SGS = SG.get_local_linear_range();
38+
39+
auto Partition = syclex::get_fixed_size_group<PartitionSize>(SG);
40+
41+
bool Match = true;
42+
Match &= (Partition.get_group_id() == (WI / PartitionSize));
43+
Match &= (Partition.get_local_id() == (WI % PartitionSize));
44+
Match &= (Partition.get_group_range() == (SGS / PartitionSize));
45+
Match &= (Partition.get_local_range() == PartitionSize);
46+
MatchAcc[WI] = Match;
47+
LeaderAcc[WI] = Partition.leader();
48+
};
49+
CGH.parallel_for<TestKernel<PartitionSize>>(NDR, KernelFunc);
50+
});
51+
52+
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
53+
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
54+
for (int WI = 0; WI < WGS; ++WI) {
55+
assert(MatchAcc[WI] == true);
56+
assert(LeaderAcc[WI] == ((WI % PartitionSize) == 0));
57+
}
4758
}
4859
}
4960

sycl/test-e2e/NonUniformGroups/opportunistic_group.cpp

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,50 +20,56 @@ int main() {
2020
return 0;
2121
}
2222

23-
sycl::buffer<bool, 1> MatchBuf{sycl::range{32}};
24-
sycl::buffer<bool, 1> LeaderBuf{sycl::range{32}};
23+
// Test for both the full sub-group size and a case with less work than a full
24+
// sub-group.
25+
for (size_t WGS : std::array<size_t, 2>{32, 16}) {
26+
std::cout << "Testing for work size " << WGS << std::endl;
2527

26-
const auto NDR = sycl::nd_range<1>{32, 32};
27-
Q.submit([&](sycl::handler &CGH) {
28-
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
29-
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
30-
const auto KernelFunc =
31-
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
32-
auto WI = item.get_global_id();
33-
auto SG = item.get_sub_group();
28+
sycl::buffer<bool, 1> MatchBuf{sycl::range{WGS}};
29+
sycl::buffer<bool, 1> LeaderBuf{sycl::range{WGS}};
3430

35-
// Due to the unpredictable runtime behavior of opportunistic groups,
36-
// some values may change from run to run. Check they're in expected
37-
// ranges and consistent with other groups.
38-
if (item.get_global_id() % 2 == 0) {
39-
auto OpportunisticGroup =
40-
syclex::this_kernel::get_opportunistic_group();
31+
const auto NDR = sycl::nd_range<1>{WGS, WGS};
32+
Q.submit([&](sycl::handler &CGH) {
33+
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
34+
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
35+
const auto KernelFunc =
36+
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
37+
auto WI = item.get_global_id();
38+
auto SG = item.get_sub_group();
4139

42-
bool Match = true;
43-
Match &= (OpportunisticGroup.get_group_id() == 0);
44-
Match &= (OpportunisticGroup.get_local_id() <
45-
OpportunisticGroup.get_local_range());
46-
Match &= (OpportunisticGroup.get_group_range() == 1);
47-
Match &= (OpportunisticGroup.get_local_linear_range() <=
48-
SG.get_local_linear_range());
49-
MatchAcc[WI] = Match;
50-
LeaderAcc[WI] = OpportunisticGroup.leader();
51-
}
52-
};
53-
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
54-
});
40+
// Due to the unpredictable runtime behavior of opportunistic
41+
// groups, some values may change from run to run. Check they're in
42+
// expected ranges and consistent with other groups.
43+
if (item.get_global_id() % 2 == 0) {
44+
auto OpportunisticGroup =
45+
syclex::this_kernel::get_opportunistic_group();
5546

56-
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
57-
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
58-
uint32_t NumLeaders = 0;
59-
for (int WI = 0; WI < 32; ++WI) {
60-
if (WI % 2 == 0) {
61-
assert(MatchAcc[WI] == true);
62-
if (LeaderAcc[WI]) {
63-
NumLeaders++;
47+
bool Match = true;
48+
Match &= (OpportunisticGroup.get_group_id() == 0);
49+
Match &= (OpportunisticGroup.get_local_id() <
50+
OpportunisticGroup.get_local_range());
51+
Match &= (OpportunisticGroup.get_group_range() == 1);
52+
Match &= (OpportunisticGroup.get_local_linear_range() <=
53+
SG.get_local_linear_range());
54+
MatchAcc[WI] = Match;
55+
LeaderAcc[WI] = OpportunisticGroup.leader();
56+
}
57+
};
58+
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
59+
});
60+
61+
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
62+
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
63+
uint32_t NumLeaders = 0;
64+
for (int WI = 0; WI < WGS; ++WI) {
65+
if (WI % 2 == 0) {
66+
assert(MatchAcc[WI] == true);
67+
if (LeaderAcc[WI]) {
68+
NumLeaders++;
69+
}
6470
}
6571
}
72+
assert(NumLeaders > 0);
6673
}
67-
assert(NumLeaders > 0);
6874
return 0;
6975
}

sycl/test-e2e/NonUniformGroups/tangle_group.cpp

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,51 +20,59 @@ int main() {
2020
return 0;
2121
}
2222

23-
sycl::buffer<bool, 1> MatchBuf{sycl::range{32}};
24-
sycl::buffer<bool, 1> LeaderBuf{sycl::range{32}};
23+
// Test for both the full sub-group size and a case with less work than a full
24+
// sub-group.
25+
for (size_t WGS : std::array<size_t, 2>{32, 16}) {
26+
std::cout << "Testing for work size " << WGS << std::endl;
2527

26-
const auto NDR = sycl::nd_range<1>{32, 32};
27-
Q.submit([&](sycl::handler &CGH) {
28-
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
29-
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
30-
const auto KernelFunc =
31-
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
32-
auto WI = item.get_global_id();
33-
auto SG = item.get_sub_group();
28+
sycl::buffer<bool, 1> MatchBuf{sycl::range{WGS}};
29+
sycl::buffer<bool, 1> LeaderBuf{sycl::range{WGS}};
3430

35-
// Split into odd and even work-items via control flow.
36-
// Branches deliberately duplicated to test impact of optimizations.
37-
// This only reliably works with optimizations disabled right now.
38-
if (item.get_global_id() % 2 == 0) {
39-
auto TangleGroup = syclex::get_tangle_group(SG);
31+
const auto NDR = sycl::nd_range<1>{WGS, WGS};
32+
Q.submit([&](sycl::handler &CGH) {
33+
sycl::accessor MatchAcc{MatchBuf, CGH, sycl::write_only};
34+
sycl::accessor LeaderAcc{LeaderBuf, CGH, sycl::write_only};
35+
const auto KernelFunc =
36+
[=](sycl::nd_item<1> item) [[sycl::reqd_sub_group_size(32)]] {
37+
auto WI = item.get_global_id();
38+
auto SG = item.get_sub_group();
4039

41-
bool Match = true;
42-
Match &= (TangleGroup.get_group_id() == 0);
43-
Match &= (TangleGroup.get_local_id() == SG.get_local_id() / 2);
44-
Match &= (TangleGroup.get_group_range() == 1);
45-
Match &= (TangleGroup.get_local_range() == 16);
46-
MatchAcc[WI] = Match;
47-
LeaderAcc[WI] = TangleGroup.leader();
48-
} else {
49-
auto TangleGroup = syclex::get_tangle_group(SG);
40+
// Split into odd and even work-items via control flow.
41+
// Branches deliberately duplicated to test impact of optimizations.
42+
// This only reliably works with optimizations disabled right now.
43+
if (item.get_global_id() % 2 == 0) {
44+
auto TangleGroup = syclex::get_tangle_group(SG);
5045

51-
bool Match = true;
52-
Match &= (TangleGroup.get_group_id() == 0);
53-
Match &= (TangleGroup.get_local_id() == SG.get_local_id() / 2);
54-
Match &= (TangleGroup.get_group_range() == 1);
55-
Match &= (TangleGroup.get_local_range() == 16);
56-
MatchAcc[WI] = Match;
57-
LeaderAcc[WI] = TangleGroup.leader();
58-
}
59-
};
60-
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
61-
});
46+
bool Match = true;
47+
Match &= (TangleGroup.get_group_id() == 0);
48+
Match &= (TangleGroup.get_local_id() == SG.get_local_id() / 2);
49+
Match &= (TangleGroup.get_group_range() == 1);
50+
Match &= (TangleGroup.get_local_range() ==
51+
SG.get_local_linear_range() / 2);
52+
MatchAcc[WI] = Match;
53+
LeaderAcc[WI] = TangleGroup.leader();
54+
} else {
55+
auto TangleGroup = syclex::get_tangle_group(SG);
6256

63-
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
64-
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
65-
for (int WI = 0; WI < 32; ++WI) {
66-
assert(MatchAcc[WI] == true);
67-
assert(LeaderAcc[WI] == (WI < 2));
57+
bool Match = true;
58+
Match &= (TangleGroup.get_group_id() == 0);
59+
Match &= (TangleGroup.get_local_id() == SG.get_local_id() / 2);
60+
Match &= (TangleGroup.get_group_range() == 1);
61+
Match &= (TangleGroup.get_local_range() ==
62+
SG.get_local_linear_range() / 2);
63+
MatchAcc[WI] = Match;
64+
LeaderAcc[WI] = TangleGroup.leader();
65+
}
66+
};
67+
CGH.parallel_for<TestKernel>(NDR, KernelFunc);
68+
});
69+
70+
sycl::host_accessor MatchAcc{MatchBuf, sycl::read_only};
71+
sycl::host_accessor LeaderAcc{LeaderBuf, sycl::read_only};
72+
for (int WI = 0; WI < WGS; ++WI) {
73+
assert(MatchAcc[WI] == true);
74+
assert(LeaderAcc[WI] == (WI < 2));
75+
}
6876
}
6977
return 0;
7078
}

0 commit comments

Comments
 (0)