Skip to content

Commit b604fcb

Browse files
committed
[runtime] Move prolog/epilog block to a post-simplify strategy
The runtime unroller will try to produce a non-loop if the unroll count is 2 and thus the prolog/epilog loop would run at most one iteration. The old implementation did this by avoiding loop construction entirely. This patch instead constructs the trivial loop and then explicitly breaks the backedge and simplifies. This does result in some additional code churn when triggered, but a) results in better quality code and b) removes a codepath which didn't work properly for multiple-exit epilogs. One oddity that I want to draw to reviewers' attention is that this somehow changes revisit order. The new order looks equivalent to me, but I don't understand how creating and erasing an extra loop here creates this effect. Differential Revision: https://reviews.llvm.org/D108521
1 parent 9b45fd9 commit b604fcb

File tree

7 files changed

+210
-322
lines changed

7 files changed

+210
-322
lines changed

llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "llvm/ADT/SmallPtrSet.h"
2424
#include "llvm/ADT/Statistic.h"
25+
#include "llvm/Analysis/InstructionSimplify.h"
2526
#include "llvm/Analysis/LoopIterator.h"
2627
#include "llvm/Analysis/ScalarEvolution.h"
2728
#include "llvm/IR/BasicBlock.h"
@@ -35,6 +36,7 @@
3536
#include "llvm/Transforms/Utils.h"
3637
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
3738
#include "llvm/Transforms/Utils/Cloning.h"
39+
#include "llvm/Transforms/Utils/Local.h"
3840
#include "llvm/Transforms/Utils/LoopUtils.h"
3941
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
4042
#include "llvm/Transforms/Utils/UnrollLoop.h"
@@ -299,17 +301,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
299301
PreserveLCSSA);
300302
}
301303

302-
/// Create a clone of the blocks in a loop and connect them together.
303-
/// If CreateRemainderLoop is false, loop structure will not be cloned,
304-
/// otherwise a new loop will be created including all cloned blocks, and the
305-
/// iterator of it switches to count NewIter down to 0.
304+
/// Create a clone of the blocks in a loop and connect them together. A new
305+
/// loop will be created including all cloned blocks, and the iterator of the
306+
/// new loop switched to count NewIter down to 0.
306307
/// The cloned blocks should be inserted between InsertTop and InsertBot.
307-
/// If loop structure is cloned InsertTop should be new preheader, InsertBot
308-
/// new loop exit.
309-
/// Return the new cloned loop that is created when CreateRemainderLoop is true.
308+
/// InsertTop should be new preheader, InsertBot new loop exit.
309+
/// Returns the new cloned loop that is created.
310310
static Loop *
311-
CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop,
312-
const bool UseEpilogRemainder, const bool UnrollRemainder,
311+
CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
312+
const bool UnrollRemainder,
313313
BasicBlock *InsertTop,
314314
BasicBlock *InsertBot, BasicBlock *Preheader,
315315
std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks,
@@ -323,20 +323,14 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop,
323323
Loop *ParentLoop = L->getParentLoop();
324324
NewLoopsMap NewLoops;
325325
NewLoops[ParentLoop] = ParentLoop;
326-
if (!CreateRemainderLoop)
327-
NewLoops[L] = ParentLoop;
328326

329327
// For each block in the original loop, create a new copy,
330328
// and update the value map with the newly created values.
331329
for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
332330
BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F);
333331
NewBlocks.push_back(NewBB);
334332

335-
// If we're unrolling the outermost loop, there's no remainder loop,
336-
// and this block isn't in a nested loop, then the new block is not
337-
// in any loop. Otherwise, add it to loopinfo.
338-
if (CreateRemainderLoop || LI->getLoopFor(*BB) != L || ParentLoop)
339-
addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops);
333+
addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops);
340334

341335
VMap[*BB] = NewBB;
342336
if (Header == *BB) {
@@ -357,27 +351,22 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop,
357351
}
358352

359353
if (Latch == *BB) {
360-
// For the last block, if CreateRemainderLoop is false, create a direct
361-
// jump to InsertBot. If not, create a loop back to cloned head.
354+
// For the last block, create a loop back to cloned head.
362355
VMap.erase((*BB)->getTerminator());
363356
BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]);
364357
BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator());
365358
IRBuilder<> Builder(LatchBR);
366-
if (!CreateRemainderLoop) {
367-
Builder.CreateBr(InsertBot);
368-
} else {
369-
PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2,
370-
suffix + ".iter",
371-
FirstLoopBB->getFirstNonPHI());
372-
Value *IdxSub =
373-
Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1),
374-
NewIdx->getName() + ".sub");
375-
Value *IdxCmp =
376-
Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp");
377-
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot);
378-
NewIdx->addIncoming(NewIter, InsertTop);
379-
NewIdx->addIncoming(IdxSub, NewBB);
380-
}
359+
PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2,
360+
suffix + ".iter",
361+
FirstLoopBB->getFirstNonPHI());
362+
Value *IdxSub =
363+
Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1),
364+
NewIdx->getName() + ".sub");
365+
Value *IdxCmp =
366+
Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp");
367+
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot);
368+
NewIdx->addIncoming(NewIter, InsertTop);
369+
NewIdx->addIncoming(IdxSub, NewBB);
381370
LatchBR->eraseFromParent();
382371
}
383372
}
@@ -386,28 +375,15 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop,
386375
// cloned loop.
387376
for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
388377
PHINode *NewPHI = cast<PHINode>(VMap[&*I]);
389-
if (!CreateRemainderLoop) {
390-
if (UseEpilogRemainder) {
391-
unsigned idx = NewPHI->getBasicBlockIndex(Preheader);
392-
NewPHI->setIncomingBlock(idx, InsertTop);
393-
NewPHI->removeIncomingValue(Latch, false);
394-
} else {
395-
VMap[&*I] = NewPHI->getIncomingValueForBlock(Preheader);
396-
cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI);
397-
}
398-
} else {
399-
unsigned idx = NewPHI->getBasicBlockIndex(Preheader);
400-
NewPHI->setIncomingBlock(idx, InsertTop);
401-
BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]);
402-
idx = NewPHI->getBasicBlockIndex(Latch);
403-
Value *InVal = NewPHI->getIncomingValue(idx);
404-
NewPHI->setIncomingBlock(idx, NewLatch);
405-
if (Value *V = VMap.lookup(InVal))
406-
NewPHI->setIncomingValue(idx, V);
407-
}
378+
unsigned idx = NewPHI->getBasicBlockIndex(Preheader);
379+
NewPHI->setIncomingBlock(idx, InsertTop);
380+
BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]);
381+
idx = NewPHI->getBasicBlockIndex(Latch);
382+
Value *InVal = NewPHI->getIncomingValue(idx);
383+
NewPHI->setIncomingBlock(idx, NewLatch);
384+
if (Value *V = VMap.lookup(InVal))
385+
NewPHI->setIncomingValue(idx, V);
408386
}
409-
if (!CreateRemainderLoop)
410-
return nullptr;
411387

412388
Loop *NewLoop = NewLoops[L];
413389
assert(NewLoop && "L should have been cloned");
@@ -819,18 +795,13 @@ bool llvm::UnrollRuntimeLoopRemainder(
819795
std::vector<BasicBlock *> NewBlocks;
820796
ValueToValueMapTy VMap;
821797

822-
// For unroll factor 2 remainder loop will have 1 iterations.
823-
// Do not create 1 iteration loop.
824-
bool CreateRemainderLoop = (Count != 2);
825-
826798
// Clone all the basic blocks in the loop, creating a cloned loop to execute
827799
// the extra iterations. If Count is 2, that single-iteration remainder loop
828800
// is broken up and simplified later. This function adds the appropriate CFG connections.
829801
BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
830802
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
831803
Loop *remainderLoop = CloneLoopBlocks(
832-
L, ModVal, CreateRemainderLoop, UseEpilogRemainder, UnrollRemainder,
833-
InsertTop, InsertBot,
804+
L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
834805
NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI);
835806

836807
// Assign the maximum possible trip count as the back edge weight for the
@@ -974,6 +945,42 @@ bool llvm::UnrollRuntimeLoopRemainder(
974945
assert(DT->verify(DominatorTree::VerificationLevel::Full));
975946
#endif
976947

948+
// For unroll factor 2 remainder loop will have 1 iteration.
949+
if (Count == 2 && DT && LI && SE) {
950+
// TODO: This code could probably be pulled out into a helper function
951+
// (e.g. breakLoopBackedgeAndSimplify) and reused in loop-deletion.
952+
BasicBlock *RemainderLatch = remainderLoop->getLoopLatch();
953+
assert(RemainderLatch);
954+
SmallVector<BasicBlock*> RemainderBlocks(remainderLoop->getBlocks().begin(),
955+
remainderLoop->getBlocks().end());
956+
breakLoopBackedge(remainderLoop, *DT, *SE, *LI, nullptr);
957+
remainderLoop = nullptr;
958+
959+
// Simplify loop values after breaking the backedge
960+
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
961+
SmallVector<WeakTrackingVH, 16> DeadInsts;
962+
for (BasicBlock *BB : RemainderBlocks) {
963+
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
964+
Instruction *Inst = &*I++;
965+
if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC}))
966+
if (LI->replacementPreservesLCSSAForm(Inst, V))
967+
Inst->replaceAllUsesWith(V);
968+
if (isInstructionTriviallyDead(Inst))
969+
DeadInsts.emplace_back(Inst);
970+
}
971+
// We can't do recursive deletion until we're done iterating, as we might
972+
// have a phi which (potentially indirectly) uses instructions later in
973+
// the block we're iterating through.
974+
RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
975+
}
976+
977+
// Merge latch into exit block.
978+
auto *ExitBB = RemainderLatch->getSingleSuccessor();
979+
assert(ExitBB && "required after breaking cond br backedge");
980+
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
981+
MergeBlockIntoPredecessor(ExitBB, &DTU, LI);
982+
}
983+
977984
// Canonicalize to LoopSimplifyForm both original and remainder loops. We
978985
// cannot rely on the LoopUnrollPass to do this because it only does
979986
// canonicalization for parent/subloops and not the sibling loops.

0 commit comments

Comments
 (0)