Skip to content

Commit 69831f6

Browse files
committed
fix block sorting
Basic block ordering was not completely correct, as it could end up placing a post-dominating block before the block it post-dominates. Using the partial ordering is more correct, but this was a bit unstable due to the unspecified order when 2 blocks shared the same rank. Fixed the iterator to be stable. Signed-off-by: Nathan Gauër <[email protected]>
1 parent ee9bb95 commit 69831f6

35 files changed

+815
-790
lines changed

llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -175,23 +175,6 @@ void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
175175
visit(MF, *MF.begin(), op);
176176
}
177177

178-
// Sorts basic blocks by dominance to respect the SPIR-V spec.
179-
void sortBlocks(MachineFunction &MF) {
180-
MachineDominatorTree MDT(MF);
181-
182-
std::unordered_map<MachineBasicBlock *, size_t> Order;
183-
Order.reserve(MF.size());
184-
185-
size_t Index = 0;
186-
visit(MF, [&Order, &Index](MachineBasicBlock *MBB) { Order[MBB] = Index++; });
187-
188-
auto Comparator = [&Order](MachineBasicBlock &LHS, MachineBasicBlock &RHS) {
189-
return Order[&LHS] < Order[&RHS];
190-
};
191-
192-
MF.sort(Comparator);
193-
}
194-
195178
bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
196179
// Initialize the type registry.
197180
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
@@ -200,7 +183,6 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
200183
MachineIRBuilder MIB(MF);
201184

202185
processNewInstrs(MF, GR, MIB);
203-
sortBlocks(MF);
204186

205187
return true;
206188
}

llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp

Lines changed: 91 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,18 @@ using Edge = std::pair<BasicBlock *, BasicBlock *>;
7171
class PartialOrderingVisitor {
7272
DomTreeBuilder::BBDomTree DT;
7373
LoopInfo LI;
74-
BlockSet Visited;
75-
std::unordered_map<BasicBlock *, size_t> B2R;
76-
std::vector<std::pair<BasicBlock *, size_t>> Order;
74+
BlockSet Visited = {};
75+
76+
struct OrderInfo {
77+
size_t Rank;
78+
size_t TraversalIndex;
79+
};
80+
81+
using BlockToOrderInfoMap = std::unordered_map<BasicBlock *, OrderInfo>;
82+
BlockToOrderInfoMap BlockToOrder;
83+
84+
// std::unordered_map<BasicBlock *, std::pair<size_t, size_t>> B2R = {};
85+
std::vector<BasicBlock *> Order = {};
7786

7887
// Get all basic-blocks reachable from Start.
7988
BlockSet getReachableFrom(BasicBlock *Start) {
@@ -106,18 +115,19 @@ class PartialOrderingVisitor {
106115
Loop *L = LI.getLoopFor(BB);
107116
const bool isLoopHeader = LI.isLoopHeader(BB);
108117

109-
if (B2R.count(BB) == 0) {
110-
B2R.emplace(BB, Rank);
118+
if (BlockToOrder.count(BB) == 0) {
119+
OrderInfo Info = {Rank, Visited.size()};
120+
BlockToOrder.emplace(BB, Info);
111121
} else {
112-
B2R[BB] = std::max(B2R[BB], Rank);
122+
BlockToOrder[BB].Rank = std::max(BlockToOrder[BB].Rank, Rank);
113123
}
114124

115125
for (BasicBlock *Predecessor : predecessors(BB)) {
116126
if (isLoopHeader && L->contains(Predecessor)) {
117127
continue;
118128
}
119129

120-
if (B2R.count(Predecessor) == 0) {
130+
if (BlockToOrder.count(Predecessor) == 0) {
121131
return Rank;
122132
}
123133
}
@@ -155,45 +165,56 @@ class PartialOrderingVisitor {
155165

156166
visit(&*F.begin(), 0);
157167

158-
for (auto &[BB, Rank] : B2R)
159-
Order.emplace_back(BB, Rank);
168+
Order.reserve(F.size());
169+
for (auto &[BB, Info] : BlockToOrder)
170+
Order.emplace_back(BB);
160171

161-
std::sort(Order.begin(), Order.end(), [](const auto &LHS, const auto &RHS) {
162-
return LHS.second < RHS.second;
163-
});
164-
165-
for (size_t i = 0; i < Order.size(); i++)
166-
B2R[Order[i].first] = i;
172+
std::sort(
173+
Order.begin(), Order.end(),
174+
[&](const auto &LHS, const auto &RHS) { return compare(LHS, RHS); });
167175
}
168176

169-
size_t getRank(BasicBlock *BB) {
170-
return B2R[BB];
177+
bool compare(const BasicBlock *LHS, const BasicBlock *RHS) const {
178+
const OrderInfo &InfoLHS = BlockToOrder.at(const_cast<BasicBlock *>(LHS));
179+
const OrderInfo &InfoRHS = BlockToOrder.at(const_cast<BasicBlock *>(RHS));
180+
if (InfoLHS.Rank != InfoRHS.Rank)
181+
return InfoLHS.Rank < InfoRHS.Rank;
182+
return InfoLHS.TraversalIndex < InfoRHS.TraversalIndex;
171183
}
172184

173185
// Visit the function starting from the basic block |Start|, and calling |Op|
174186
// on each visited BB. This traversal ignores back-edges, meaning this won't
175187
// visit a node to which |Start| is not an ancestor.
188+
// If Op returns |true|, the visitor continues. If |Op| returns false, the
189+
// visitor will stop at that rank. This means if 2 nodes share the same rank,
190+
// and Op returns false when visiting the first, the second will be visited
191+
// afterwards. But none of their successors will.
176192
void partialOrderVisit(BasicBlock &Start,
177193
std::function<bool(BasicBlock *)> Op) {
178194
BlockSet Reachable = getReachableFrom(&Start);
179-
assert(B2R.count(&Start) != 0);
180-
size_t Rank = Order[B2R[&Start]].second;
195+
assert(BlockToOrder.count(&Start) != 0);
181196

197+
// Skipping blocks with a rank inferior to |Start|'s rank.
182198
auto It = Order.begin();
183-
while (It != Order.end() && It->second < Rank)
199+
while (It != Order.end() && *It != &Start)
184200
++It;
185201

186-
if (It == Order.end())
187-
return;
202+
// This is unexpected. Worst case |Start| is the last block,
203+
// so It should point to the last block, not past-end.
204+
assert(It != Order.end());
188205

189-
size_t EndRank = Order.rbegin()->second + 1;
190-
for (; It != Order.end() && It->second <= EndRank; ++It) {
191-
if (Reachable.count(It->first) == 0) {
206+
// By default, there is no rank limit. Setting it to the maximum value.
207+
std::optional<size_t> EndRank = std::nullopt;
208+
for (; It != Order.end(); ++It) {
209+
if (EndRank.has_value() && BlockToOrder[*It].Rank > *EndRank)
210+
break;
211+
212+
if (Reachable.count(*It) == 0) {
192213
continue;
193214
}
194215

195-
if (!Op(It->first)) {
196-
EndRank = It->second;
216+
if (!Op(*It)) {
217+
EndRank = BlockToOrder[*It].Rank;
197218
}
198219
}
199220
}
@@ -641,7 +662,6 @@ class SPIRVStructurizer : public FunctionPass {
641662
auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
642663
IRBuilder<> ExitBuilder(NewExit);
643664

644-
BlockSet SeenDst;
645665
std::vector<BasicBlock *> Dsts;
646666
std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
647667

@@ -846,10 +866,10 @@ class SPIRVStructurizer : public FunctionPass {
846866
std::sort(MergeInstructions.begin(), MergeInstructions.end(),
847867
[&Visitor](Instruction *Left, Instruction *Right) {
848868
if (Left == Right)
849-
return true;
869+
return false;
850870
BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
851871
BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
852-
return Visitor.getRank(RightMerge) >= Visitor.getRank(LeftMerge);
872+
return !Visitor.compare(RightMerge, LeftMerge);
853873
});
854874

855875
for (Instruction *I : MergeInstructions) {
@@ -1041,8 +1061,6 @@ class SPIRVStructurizer : public FunctionPass {
10411061
assert(Node->Parent->Header && Node->Parent->Merge);
10421062

10431063
BlockSet ConstructBlocks = getConstructBlocks(S, Node);
1044-
BlockSet ParentBlocks = getConstructBlocks(S, Node->Parent);
1045-
10461064
auto Edges = getExitsFrom(ConstructBlocks, *Node->Header);
10471065

10481066
// No edges exiting the construct.
@@ -1300,6 +1318,44 @@ class SPIRVStructurizer : public FunctionPass {
13001318
return Modified;
13011319
}
13021320

1321+
// Sort blocks in a partial ordering, so each block is after all its
1322+
// dominators. This should match both the SPIR-V and the MIR requirements.
1323+
bool sortBlocks(Function &F) {
1324+
if (F.size() == 0)
1325+
return false;
1326+
1327+
bool Modified = false;
1328+
1329+
std::vector<BasicBlock *> Order;
1330+
Order.reserve(F.size());
1331+
1332+
PartialOrderingVisitor Visitor(F);
1333+
Visitor.partialOrderVisit(*F.begin(), [&Order](BasicBlock *Block) {
1334+
Order.push_back(Block);
1335+
return true;
1336+
});
1337+
1338+
assert(&*F.begin() == Order[0]);
1339+
BasicBlock *LastBlock = &*F.begin();
1340+
for (BasicBlock *BB : Order) {
1341+
if (BB != LastBlock && &*LastBlock->getNextNode() != BB) {
1342+
Modified = true;
1343+
BB->moveAfter(LastBlock);
1344+
}
1345+
LastBlock = BB;
1346+
}
1347+
#if 0
1348+
for (auto It = Order.begin() + 1; It != Order.end(); ++It) {
1349+
if (*It != &*LastBlock->getNextNode()) {
1350+
Modified = true;
1351+
(*It)->moveAfter(LastBlock);
1352+
}
1353+
LastBlock = *It;
1354+
}
1355+
#endif
1356+
return Modified;
1357+
}
1358+
13031359
public:
13041360
static char ID;
13051361

@@ -1367,6 +1423,9 @@ class SPIRVStructurizer : public FunctionPass {
13671423
// branches with 1 or 2 returning edges. Adding a header for those.
13681424
Modified |= addHeaderToRemainingDivergentDAG(F);
13691425

1426+
// STEP 9: sort basic blocks to match both the LLVM & SPIR-V requirements.
1427+
Modified |= sortBlocks(F);
1428+
13701429
return Modified;
13711430
}
13721431

llvm/test/CodeGen/SPIRV/branching/if-merging.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ merge_label:
3737
; CHECK: [[COND:%.+]] = OpIEqual [[BOOL]] [[A]] [[B]]
3838
; CHECK: OpBranchConditional [[COND]] [[TRUE_LABEL:%.+]] [[FALSE_LABEL:%.+]]
3939

40+
; CHECK: [[TRUE_LABEL]] = OpLabel
41+
; CHECK: [[V1:%.+]] = OpFunctionCall [[I32]] [[FOO]]
42+
; CHECK: OpBranch [[MERGE_LABEL:%.+]]
43+
4044
; CHECK: [[FALSE_LABEL]] = OpLabel
4145
; CHECK: [[V2:%.+]] = OpFunctionCall [[I32]] [[BAR]]
42-
; CHECK: OpBranch [[MERGE_LABEL:%.+]]
46+
; CHECK: OpBranch [[MERGE_LABEL]]
4347

4448
; CHECK: [[MERGE_LABEL]] = OpLabel
45-
; CHECK-NEXT: [[V:%.+]] = OpPhi [[I32]] [[V1:%.+]] [[TRUE_LABEL]] [[V2]] [[FALSE_LABEL]]
49+
; CHECK-NEXT: [[V:%.+]] = OpPhi [[I32]] [[V1]] [[TRUE_LABEL]] [[V2]] [[FALSE_LABEL]]
4650
; CHECK: OpReturnValue [[V]]
4751

48-
; CHECK: [[TRUE_LABEL]] = OpLabel
49-
; CHECK: [[V1]] = OpFunctionCall [[I32]] [[FOO]]
50-
; CHECK: OpBranch [[MERGE_LABEL]]
51-
5252
; CHECK-NEXT: OpFunctionEnd

llvm/test/CodeGen/SPIRV/branching/if-non-merging.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ false_label:
2121
; CHECK: [[ENTRY:%.+]] = OpLabel
2222
; CHECK: [[COND:%.+]] = OpIEqual [[BOOL]] [[A]] [[B]]
2323
; CHECK: OpBranchConditional [[COND]] [[TRUE_LABEL:%.+]] [[FALSE_LABEL:%.+]]
24-
; CHECK: [[FALSE_LABEL]] = OpLabel
25-
; CHECK: OpReturnValue [[FALSE]]
2624
; CHECK: [[TRUE_LABEL]] = OpLabel
2725
; CHECK: OpReturnValue [[TRUE]]
26+
; CHECK: [[FALSE_LABEL]] = OpLabel
27+
; CHECK: OpReturnValue [[FALSE]]

llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
; CHECK: OpFunction
55
; CHECK: OpBranchConditional %[[#]] %[[#if_then:]] %[[#if_end:]]
6+
; CHECK: %[[#if_then]] = OpLabel
7+
; CHECK: OpBranch %[[#if_end]]
68
; CHECK: %[[#if_end]] = OpLabel
79
; CHECK: %[[#Var:]] = OpPhi
810
; CHECK: OpSwitch %[[#Var]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]]
911
; CHECK-COUNT-11: OpLabel
1012
; CHECK-NOT: OpBranch
1113
; CHECK: OpReturn
12-
; CHECK: %[[#if_then]] = OpLabel
13-
; CHECK: OpBranch %[[#if_end]]
1414
; CHECK-NEXT: OpFunctionEnd
1515

1616
define spir_func void @foo(i64 noundef %addr, i64 noundef %as) {

llvm/test/CodeGen/SPIRV/phi-ptrcast-dominate.ll

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@ entry:
2323
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#l2:]]
2424
br i1 %b1, label %l1, label %l2
2525

26+
; CHECK: %[[#l1]] = OpLabel
27+
; CHECK-NEXT: OpPhi
28+
; CHECK: OpBranch %[[#exit:]]
2629
l1:
2730
%str = phi ptr addrspace(1) [ @.str.1, %entry ], [ @.str.2, %l2 ], [ @.str.2, %l3 ]
2831
br label %exit
2932

3033
; CHECK: %[[#l2]] = OpLabel
31-
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#l3:]]
34+
; CHECK: OpBranchConditional %[[#]] %[[#l1]] %[[#l3:]]
3235
l2:
3336
br i1 %b2, label %l1, label %l3
3437

3538
; CHECK: %[[#l3]] = OpLabel
36-
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#exit:]]
39+
; CHECK: OpBranchConditional %[[#]] %[[#l1]] %[[#exit]]
3740
l3:
3841
br i1 %b3, label %l1, label %exit
3942

@@ -42,9 +45,6 @@ l3:
4245
exit:
4346
ret void
4447

45-
; CHECK: %[[#l1]] = OpLabel
46-
; CHECK-NEXT: OpPhi
47-
; CHECK: OpBranch %[[#exit:]]
4848
}
4949

5050
; CHECK: %[[#Case2]] = OpFunction
@@ -53,6 +53,9 @@ entry:
5353
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#l2:]]
5454
br i1 %b1, label %l1, label %l2
5555

56+
; CHECK: %[[#l1]] = OpLabel
57+
; CHECK-NEXT: OpPhi
58+
; CHECK: OpBranch %[[#exit:]]
5659
l1:
5760
%str = phi ptr addrspace(1) [ %str1, %entry ], [ %str2, %l2 ], [ %str2, %l3 ]
5861
br label %exit
@@ -75,10 +78,14 @@ exit:
7578

7679
; CHECK: %[[#Case3]] = OpFunction
7780
define spir_func void @case3(i1 %b1, i1 %b2, i1 %b3, ptr addrspace(1) byval(%struct1) %_arg_str1, ptr addrspace(1) byval(%struct2) %_arg_str2) {
81+
82+
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#l2:]]
7883
entry:
7984
br i1 %b1, label %l1, label %l2
8085

81-
; CHECK: OpBranchConditional %[[#]] %[[#l1:]] %[[#l2:]]
86+
; CHECK: %[[#l1]] = OpLabel
87+
; CHECK-NEXT: OpPhi
88+
; CHECK: OpBranch %[[#exit:]]
8289
l1:
8390
%str = phi ptr addrspace(1) [ %_arg_str1, %entry ], [ %str2, %l2 ], [ %str3, %l3 ]
8491
br label %exit
@@ -101,8 +108,4 @@ l3:
101108
; CHECK: OpReturn
102109
exit:
103110
ret void
104-
105-
; CHECK: %[[#l1]] = OpLabel
106-
; CHECK-NEXT: OpPhi
107-
; CHECK: OpBranch %[[#exit:]]
108111
}

0 commit comments

Comments
 (0)