Skip to content

Commit 858db62

Browse files
authored
[SYCL][Fusion] Abort fusion on non-uniform work-group sizes ND-range (#12077)
Some ND-ranges combinations may result in non-uniform work-group sizes in the fused ND-range, e.g., fusing `{{9}, {3}}` and `{512}` would yield `{{512}, {3}}`. Abort fusion in these cases. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 3546ab8 commit 858db62

File tree

3 files changed

+48
-19
lines changed

3 files changed

+48
-19
lines changed

sycl-fusion/common/lib/NDRangesHelper.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ bool jit_compiler::isHeterogeneousList(ArrayRef<NDRange> NDRanges) {
8787
return any_of(NDRanges, [&ND](const auto &Other) { return ND != Other; });
8888
}
8989

90+
static bool wouldYieldUniformWorkGroupSize(const Indices &LocalSize,
91+
llvm::ArrayRef<NDRange> NDRanges) {
92+
const auto GlobalSize = getMaximalGlobalSize(NDRanges);
93+
return llvm::all_of(llvm::zip_equal(GlobalSize, LocalSize),
94+
[](const std::tuple<std::size_t, std::size_t> &P) {
95+
const auto &[GlobalSize, LocalSize] = P;
96+
return GlobalSize % LocalSize == 0;
97+
});
98+
}
99+
90100
bool jit_compiler::isValidCombination(llvm::ArrayRef<NDRange> NDRanges) {
91101
if (NDRanges.empty()) {
92102
return false;
@@ -95,9 +105,14 @@ bool jit_compiler::isValidCombination(llvm::ArrayRef<NDRange> NDRanges) {
95105
const auto &ND = FirstSpecifiedLocalSize == NDRanges.end()
96106
? NDRanges.front()
97107
: *FirstSpecifiedLocalSize;
98-
return llvm::all_of(NDRanges, [&ND](const auto &Other) {
99-
return compatibleRanges(ND, Other);
100-
});
108+
return llvm::all_of(NDRanges,
109+
[&ND](const auto &Other) {
110+
return compatibleRanges(ND, Other);
111+
}) &&
112+
// Either no local size is specified or the maximal global size is
113+
// compatible with the specified local size.
114+
(FirstSpecifiedLocalSize == NDRanges.end() ||
115+
wouldYieldUniformWorkGroupSize(ND.getLocalSize(), NDRanges));
101116
}
102117

103118
bool jit_compiler::requireIDRemapping(const NDRange &LHS, const NDRange &RHS) {

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ FusionResult KernelFusion::fuseKernels(
8787

8888
if (!isValidCombination(NDRanges)) {
8989
return FusionResult{
90-
"Cannot fuse kernels with different offsets or local sizes or "
91-
"different global sizes in dimensions [2, N) and non-zero offsets"};
90+
"Cannot fuse kernels with different offsets or local sizes, or "
91+
"different global sizes in dimensions [2, N) and non-zero offsets, "
92+
"or those whose fusion would yield non-uniform work-groups sizes"};
9293
}
9394

9495
bool IsHeterogeneousList = jit_compiler::isHeterogeneousList(NDRanges);

sycl/test-e2e/KernelFusion/abort_fusion.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@ constexpr size_t dataSize = 512;
1414

1515
enum class Internalization { None, Local, Private };
1616

17-
template <typename Kernel1Name, typename Kernel2Name, int Kernel1Dim>
18-
void performFusion(queue &q, range<Kernel1Dim> k1Global,
19-
range<Kernel1Dim> k1Local) {
17+
template <typename Range> size_t getSize(Range r);
18+
19+
template <> size_t getSize(range<1> r) { return r.size(); }
20+
template <> size_t getSize(nd_range<1> r) {
21+
return r.get_global_range().size();
22+
}
23+
24+
template <typename Kernel1Name, typename Kernel2Name, typename Range1,
25+
typename Range2>
26+
void performFusion(queue &q, Range1 R1, Range2 R2) {
2027
int in[dataSize], tmp[dataSize], out[dataSize];
2128

2229
for (size_t i = 0; i < dataSize; ++i) {
@@ -37,19 +44,15 @@ void performFusion(queue &q, range<Kernel1Dim> k1Global,
3744
q.submit([&](handler &cgh) {
3845
auto accIn = bIn.get_access(cgh);
3946
auto accTmp = bTmp.get_access(cgh);
40-
cgh.parallel_for<Kernel1Name>(nd_range<Kernel1Dim>{k1Global, k1Local},
41-
[=](item<Kernel1Dim> i) {
42-
auto LID = i.get_linear_id();
43-
accTmp[LID] = accIn[LID] + 5;
44-
});
47+
cgh.parallel_for<Kernel1Name>(
48+
R1, [=](item<1> i) { accTmp[i] = accIn[i] + 5; });
4549
});
4650

4751
q.submit([&](handler &cgh) {
4852
auto accTmp = bTmp.get_access(cgh);
4953
auto accOut = bOut.get_access(cgh);
50-
cgh.parallel_for<Kernel2Name>(nd_range<1>{{dataSize}, {8}}, [=](id<1> i) {
51-
accOut[i] = accTmp[i] * 2;
52-
});
54+
cgh.parallel_for<Kernel2Name>(
55+
R2, [=](id<1> i) { accOut[i] = accTmp[i] * 2; });
5356
});
5457

5558
fw.complete_fusion({ext::codeplay::experimental::property::no_barriers{}});
@@ -60,7 +63,8 @@ void performFusion(queue &q, range<Kernel1Dim> k1Global,
6063

6164
// Check the results
6265
size_t numErrors = 0;
63-
for (size_t i = 0; i < k1Global.size(); ++i) {
66+
size_t size = getSize(R1);
67+
for (size_t i = 0; i < size; ++i) {
6468
if (out[i] != ((i + 5) * 2)) {
6569
++numErrors;
6670
}
@@ -89,8 +93,9 @@ int main() {
8993

9094
// Scenario: Fusing two kernels with different local size should lead to
9195
// fusion being aborted.
92-
performFusion<class Kernel1_3, class Kernel2_3>(q, range<1>{dataSize},
93-
range<1>{16});
96+
performFusion<class Kernel1_3, class Kernel2_3>(
97+
q, nd_range<1>{range<1>{dataSize}, range<1>{16}},
98+
nd_range<1>{range<1>{dataSize}, range<1>{8}});
9499
// CHECK: ERROR: JIT compilation for kernel fusion failed with message:
95100
// CHECK-NEXT: Cannot fuse kernels with different offsets or local sizes
96101
// CHECK: COMPUTATION OK
@@ -101,5 +106,13 @@ int main() {
101106
// CHECK-NOT: Cannot fuse kernels with different offsets or local sizes
102107
// CHECK: WARNING: Fusion list is empty
103108

109+
// Scenario: Fusing two kernels that would lead to non-uniform work-group
110+
// sizes should lead to fusion being aborted.
111+
performFusion<class Kernel1_4, class Kernel2_4>(
112+
q, nd_range<1>{range<1>{9}, range<1>{3}}, range<1>{dataSize});
113+
// CHECK: ERROR: JIT compilation for kernel fusion failed with message:
114+
// CHECK-NEXT: Cannot fuse kernels with different offsets or local sizes
115+
// CHECK: COMPUTATION OK
116+
104117
return 0;
105118
}

0 commit comments

Comments
 (0)