Skip to content

Commit 00ce1ff

Browse files
committed
[SPIR-V] Fix block sorting with irreducible CFG
Block sorting was assuming a reducible CFG, meaning we always had a best node to continue with. An irreducible CFG breaks this assumption, so the algorithm looped indefinitely because no node was a valid candidate. Fixes #116692 Signed-off-by: Nathan Gauër <[email protected]>
1 parent 8a6a76b commit 00ce1ff

File tree

4 files changed

+301
-21
lines changed

4 files changed

+301
-21
lines changed

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,12 @@ size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
525525
continue;
526526

527527
auto Iterator = BlockToOrder.end();
528+
// This block hasn't been ranked yet. Ignoring.
529+
// This doesn't happen often, but when dealing with irreducible CFG, we have
530+
// to rank nodes without knowing the rank of all their predecessors.
531+
if (Iterator == BlockToOrder.end())
532+
continue;
533+
528534
Loop *L = LI.getLoopFor(P);
529535
BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
530536

@@ -550,15 +556,27 @@ size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
550556
ToVisit.push(BB);
551557
Queued.insert(BB);
552558

559+
// When the graph is irreducible, we can end up in a case where each
560+
// node has a predecessor we haven't ranked yet.
561+
// When such a case arises, we have to pick a node to continue.
562+
// This index is used to determine when we looped through all candidates.
563+
// Each time a candidate is processed, this counter is reset.
564+
// If the index is larger than the queue size, it means we looped.
565+
size_t QueueIndex = 0;
566+
553567
while (ToVisit.size() != 0) {
554568
BasicBlock *BB = ToVisit.front();
555569
ToVisit.pop();
556570

557-
if (!CanBeVisited(BB)) {
571+
// Either the node is a candidate, or we looped already, and this is
572+
// the first node we tried.
573+
if (!CanBeVisited(BB) && QueueIndex <= ToVisit.size()) {
558574
ToVisit.push(BB);
575+
QueueIndex++;
559576
continue;
560577
}
561578

579+
QueueIndex = 0;
562580
size_t Rank = GetNodeRank(BB);
563581
OrderInfo Info = {Rank, BlockToOrder.size()};
564582
BlockToOrder.emplace(BB, Info);

llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,28 @@
3434
; CHECK: %[[#bb30:]] = OpLabel
3535
; CHECK: OpSelectionMerge %[[#bb31:]] None
3636
; CHECK: OpBranchConditional %[[#]] %[[#bb32:]] %[[#bb33:]]
37-
; CHECK: %[[#bb32:]] = OpLabel
37+
; CHECK: %[[#bb32]] = OpLabel
3838
; CHECK: OpSelectionMerge %[[#bb34:]] None
39-
; CHECK: OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34:]]
40-
; CHECK: %[[#bb33:]] = OpLabel
39+
; CHECK: OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34]]
40+
; CHECK: %[[#bb33]] = OpLabel
4141
; CHECK: OpSelectionMerge %[[#bb36:]] None
4242
; CHECK: OpBranchConditional %[[#]] %[[#bb37:]] %[[#bb38:]]
43-
; CHECK: %[[#bb35:]] = OpLabel
44-
; CHECK: OpBranch %[[#bb34:]]
45-
; CHECK: %[[#bb37:]] = OpLabel
46-
; CHECK: OpBranch %[[#bb36:]]
47-
; CHECK: %[[#bb38:]] = OpLabel
43+
; CHECK: %[[#bb35]] = OpLabel
44+
; CHECK: OpBranch %[[#bb34]]
45+
; CHECK: %[[#bb34]] = OpLabel
46+
; CHECK: OpBranch %[[#bb31]]
47+
; CHECK: %[[#bb37]] = OpLabel
48+
; CHECK: OpBranch %[[#bb36]]
49+
; CHECK: %[[#bb38]] = OpLabel
4850
; CHECK: OpSelectionMerge %[[#bb39:]] None
49-
; CHECK: OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39:]]
50-
; CHECK: %[[#bb34:]] = OpLabel
51-
; CHECK: OpBranch %[[#bb31:]]
52-
; CHECK: %[[#bb40:]] = OpLabel
53-
; CHECK: OpBranch %[[#bb39:]]
54-
; CHECK: %[[#bb39:]] = OpLabel
55-
; CHECK: OpBranch %[[#bb36:]]
56-
; CHECK: %[[#bb36:]] = OpLabel
57-
; CHECK: OpBranch %[[#bb31:]]
58-
; CHECK: %[[#bb31:]] = OpLabel
51+
; CHECK: OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39]]
52+
; CHECK: %[[#bb40]] = OpLabel
53+
; CHECK: OpBranch %[[#bb39]]
54+
; CHECK: %[[#bb39]] = OpLabel
55+
; CHECK: OpBranch %[[#bb36]]
56+
; CHECK: %[[#bb36]] = OpLabel
57+
; CHECK: OpBranch %[[#bb31]]
58+
; CHECK: %[[#bb31]] = OpLabel
5959
; CHECK: OpReturnValue %[[#]]
6060
; CHECK: OpFunctionEnd
6161
; CHECK: %[[#func_26:]] = OpFunction %[[#void:]] DontInline %[[#]]

llvm/unittests/Target/SPIRV/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ set(LLVM_LINK_COMPONENTS
1515

1616
add_llvm_target_unittest(SPIRVTests
1717
SPIRVConvergenceRegionAnalysisTests.cpp
18+
SPIRVSortBlocksTests.cpp
1819
SPIRVAPITest.cpp
19-
)
20-
20+
)
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
//===- SPIRVSortBlocksTests.cpp ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "SPIRVUtils.h"
10+
#include "llvm/Analysis/DominanceFrontier.h"
11+
#include "llvm/Analysis/PostDominators.h"
12+
#include "llvm/AsmParser/Parser.h"
13+
#include "llvm/IR/Instructions.h"
14+
#include "llvm/IR/LLVMContext.h"
15+
#include "llvm/IR/LegacyPassManager.h"
16+
#include "llvm/IR/Module.h"
17+
#include "llvm/IR/PassInstrumentation.h"
18+
#include "llvm/IR/Type.h"
19+
#include "llvm/IR/TypedPointerType.h"
20+
#include "llvm/Support/SourceMgr.h"
21+
22+
#include "gmock/gmock.h"
23+
#include "gtest/gtest.h"
24+
#include <queue>
25+
26+
using namespace llvm;
27+
using namespace llvm::SPIRV;
28+
29+
class SPIRVSortBlocksTest : public testing::Test {
30+
protected:
31+
void TearDown() override { M.reset(); }
32+
33+
bool run(StringRef Assembly) {
34+
assert(M == nullptr &&
35+
"Calling runAnalysis multiple times is unsafe. See getAnalysis().");
36+
37+
SMDiagnostic Error;
38+
M = parseAssemblyString(Assembly, Error, Context);
39+
assert(M && "Bad assembly. Bad test?");
40+
llvm::Function *F = M->getFunction("main");
41+
return sortBlocks(*F);
42+
}
43+
44+
void checkBasicBlockOrder(std::vector<const char *> &&Expected) {
45+
llvm::Function *F = M->getFunction("main");
46+
auto It = F->begin();
47+
for (const auto *Name : Expected) {
48+
ASSERT_TRUE(It != F->end())
49+
<< "Expected block \"" << Name
50+
<< "\" but reached the end of the function instead.";
51+
ASSERT_TRUE(It->getName() == Name)
52+
<< "Error: expected block \"" << Name << "\" got \"" << It->getName()
53+
<< "\"";
54+
It++;
55+
}
56+
EXPECT_TRUE(It == F->end());
57+
ASSERT_TRUE(It == F->end())
58+
<< "No more blocks were expected, but function has more.";
59+
}
60+
61+
protected:
62+
LLVMContext Context;
63+
std::unique_ptr<Module> M;
64+
};
65+
66+
TEST_F(SPIRVSortBlocksTest, DefaultRegion) {
67+
StringRef Assembly = R"(
68+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
69+
ret void
70+
}
71+
)";
72+
73+
EXPECT_FALSE(run(Assembly));
74+
}
75+
76+
TEST_F(SPIRVSortBlocksTest, BasicBlockSwap) {
77+
StringRef Assembly = R"(
78+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
79+
entry:
80+
br label %middle
81+
exit:
82+
ret void
83+
middle:
84+
br label %exit
85+
}
86+
)";
87+
88+
EXPECT_TRUE(run(Assembly));
89+
checkBasicBlockOrder({"entry", "middle", "exit"});
90+
}
91+
92+
// Simple loop:
93+
// entry -> header <-----------------+
94+
// | `-> body -> continue -+
95+
// `-> end
96+
TEST_F(SPIRVSortBlocksTest, LoopOrdering) {
97+
StringRef Assembly = R"(
98+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
99+
entry:
100+
%1 = icmp ne i32 0, 0
101+
br label %header
102+
end:
103+
ret void
104+
body:
105+
br label %continue
106+
continue:
107+
br label %header
108+
header:
109+
br i1 %1, label %body, label %end
110+
}
111+
)";
112+
113+
EXPECT_TRUE(run(Assembly));
114+
checkBasicBlockOrder({"entry", "header", "body", "continue", "end"});
115+
}
116+
117+
// Diamond condition:
118+
// +-> A -+
119+
// entry -+ +-> C
120+
// +-> B -+
121+
//
122+
// A and B order can be flipped with no effect, but it must remain
123+
// deterministic/stable.
124+
TEST_F(SPIRVSortBlocksTest, DiamondCondition) {
125+
StringRef Assembly = R"(
126+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
127+
entry:
128+
%1 = icmp ne i32 0, 0
129+
br i1 %1, label %a, label %b
130+
c:
131+
ret void
132+
b:
133+
br label %c
134+
a:
135+
br label %c
136+
}
137+
)";
138+
139+
EXPECT_TRUE(run(Assembly));
140+
checkBasicBlockOrder({"entry", "a", "b", "c"});
141+
}
142+
143+
// Skip condition:
144+
// +-> A -+
145+
// entry -+ +-> C
146+
// +------+
147+
TEST_F(SPIRVSortBlocksTest, SkipCondition) {
148+
StringRef Assembly = R"(
149+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
150+
entry:
151+
%1 = icmp ne i32 0, 0
152+
br i1 %1, label %a, label %c
153+
c:
154+
ret void
155+
a:
156+
br label %c
157+
}
158+
)";
159+
160+
EXPECT_TRUE(run(Assembly));
161+
checkBasicBlockOrder({"entry", "a", "c"});
162+
}
163+
164+
// Crossing conditions:
165+
// +------+ +-> C -+
166+
// +-> A -+ | | |
167+
// entry -+ +--_|_-+ +-> E
168+
// +-> B -+ | |
169+
// +------+----> D -+
170+
//
171+
// A & B have the same rank.
172+
// C & D have the same rank, but are after A & B.
173+
// E is the last block.
174+
TEST_F(SPIRVSortBlocksTest, CrossingCondition) {
175+
StringRef Assembly = R"(
176+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
177+
entry:
178+
%1 = icmp ne i32 0, 0
179+
br i1 %1, label %a, label %b
180+
e:
181+
ret void
182+
c:
183+
br label %e
184+
b:
185+
br i1 %1, label %d, label %c
186+
d:
187+
br label %e
188+
a:
189+
br i1 %1, label %c, label %d
190+
}
191+
)";
192+
193+
EXPECT_TRUE(run(Assembly));
194+
checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
195+
}
196+
197+
// Irreducible CFG
198+
// digraph {
199+
// entry -> A;
200+
//
201+
// A -> B;
202+
// A -> C;
203+
//
204+
// B -> A;
205+
// B -> C;
206+
//
207+
// C -> B;
208+
// }
209+
//
210+
// Order starts with Entry and A. Order of B and C can change, but must remain
211+
// stable.
212+
TEST_F(SPIRVSortBlocksTest, IrreducibleOrdering) {
213+
StringRef Assembly = R"(
214+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
215+
entry:
216+
%1 = icmp ne i32 0, 0
217+
br label %a
218+
219+
b:
220+
br i1 %1, label %a, label %c
221+
222+
c:
223+
br label %b
224+
225+
a:
226+
br i1 %1, label %b, label %c
227+
228+
}
229+
)";
230+
231+
EXPECT_TRUE(run(Assembly));
232+
checkBasicBlockOrder({"entry", "a", "b", "c"});
233+
}
234+
235+
TEST_F(SPIRVSortBlocksTest, IrreducibleOrderingBeforeReduction) {
236+
StringRef Assembly = R"(
237+
define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
238+
entry:
239+
%1 = icmp ne i32 0, 0
240+
br label %a
241+
242+
c:
243+
br i1 %1, label %d, label %e
244+
245+
e:
246+
ret void
247+
248+
b:
249+
br i1 %1, label %c, label %d
250+
251+
a:
252+
br label %b
253+
254+
d:
255+
br i1 %1, label %b, label %c
256+
257+
}
258+
)";
259+
260+
EXPECT_TRUE(run(Assembly));
261+
checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
262+
}

0 commit comments

Comments
 (0)