Commit a04d4a0

[LoopFlatten] Use loop versioning when overflow can't be disproven (#78576)
Implement the TODO in loop flattening to version the loop when we can't prove that the trip count calculation won't overflow.
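For context, the pass now guards the flattened loop with a runtime check on the trip-count multiply. The following is only a rough C++ sketch of the resulting control flow (not taken from this commit; the function and variable names are illustrative, and the pass actually works on IR using llvm.umul.with.overflow rather than rewriting source):

#include <cstdint>

void f(int *);

// Sketch: flattened loop taken only when N * M provably does not wrap.
void loop_nest(int *A, uint32_t N, uint32_t M) {
  uint32_t FlatTC;
  if (!__builtin_mul_overflow(N, M, &FlatTC)) {
    // Flattened version: one loop over the combined iteration space.
    for (uint32_t K = 0; K < FlatTC; ++K)
      f(&A[K]);
  } else {
    // Original version, kept for the case where the multiply overflows.
    for (uint32_t I = 0; I < N; ++I)
      for (uint32_t J = 0; J < M; ++J)
        f(&A[I * M + J]);
  }
}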
Parent: 45cc2a1

4 files changed: +482 additions, -129 deletions

llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Lines changed: 62 additions & 13 deletions
@@ -70,6 +70,7 @@
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
 #include <optional>
@@ -97,6 +98,10 @@ static cl::opt<bool>
             cl::desc("Widen the loop induction variables, if possible, so "
                      "overflow checks won't reject flattening"));
 
+static cl::opt<bool>
+    VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true),
+                 cl::desc("Version loops if flattened loop could overflow"));
+
 namespace {
 // We require all uses of both induction variables to match this pattern:
 //
@@ -141,6 +146,8 @@ struct FlattenInfo {
                                      // has been applied. Used to skip
                                      // checks on phi nodes.
 
+  Value *NewTripCount = nullptr; // The tripcount of the flattened loop.
+
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
 
   bool isNarrowInductionPhi(PHINode *Phi) {
@@ -752,11 +759,13 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
     ORE.emit(Remark);
   }
 
-  Value *NewTripCount = BinaryOperator::CreateMul(
-      FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
-      FI.OuterLoop->getLoopPreheader()->getTerminator());
-  LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
-             NewTripCount->dump());
+  if (!FI.NewTripCount) {
+    FI.NewTripCount = BinaryOperator::CreateMul(
+        FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
+        FI.OuterLoop->getLoopPreheader()->getTerminator());
+    LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
+               FI.NewTripCount->dump());
+  }
 
   // Fix up PHI nodes that take values from the inner loop back-edge, which
   // we are about to remove.
@@ -769,7 +778,7 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
 
   // Modify the trip count of the outer loop to be the product of the two
   // trip counts.
-  cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);
+  cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount);
 
   // Replace the inner loop backedge with an unconditional branch to the exit.
   BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
@@ -891,7 +900,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
 static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
                             ScalarEvolution *SE, AssumptionCache *AC,
                             const TargetTransformInfo *TTI, LPMUpdater *U,
-                            MemorySSAUpdater *MSSAU) {
+                            MemorySSAUpdater *MSSAU,
+                            const LoopAccessInfo &LAI) {
   LLVM_DEBUG(
       dbgs() << "Loop flattening running on outer loop "
              << FI.OuterLoop->getHeader()->getName() << " and inner loop "
@@ -926,18 +936,55 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
   // variable might overflow. In this case, we need to version the loop, and
   // select the original version at runtime if the iteration space is too
   // large.
-  // TODO: We currently don't version the loop.
   OverflowResult OR = checkOverflow(FI, DT, AC);
   if (OR == OverflowResult::AlwaysOverflowsHigh ||
       OR == OverflowResult::AlwaysOverflowsLow) {
     LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
     return false;
   } else if (OR == OverflowResult::MayOverflow) {
-    LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
-    return false;
+    Module *M = FI.OuterLoop->getHeader()->getParent()->getParent();
+    const DataLayout &DL = M->getDataLayout();
+    if (!VersionLoops) {
+      LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
+      return false;
+    } else if (!DL.isLegalInteger(
+                   FI.OuterTripCount->getType()->getScalarSizeInBits())) {
+      // If the trip count type isn't legal then it won't be possible to check
+      // for overflow using only a single multiply instruction, so don't
+      // flatten.
+      LLVM_DEBUG(
+          dbgs() << "Can't check overflow efficiently, not flattening\n");
+      return false;
+    }
+    LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n");
+
+    // Version the loop. The overflow check isn't a runtime pointer check, so we
+    // pass an empty list of runtime pointer checks, causing LoopVersioning to
+    // emit 'false' as the branch condition, and add our own check afterwards.
+    BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader();
+    ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr);
+    LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE);
+    LVer.versionLoop();
+
+    // Check for overflow by calculating the new tripcount using
+    // umul_with_overflow and then checking if it overflowed.
+    BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator());
+    assert(Br->isConditional() &&
+           "Expected LoopVersioning to generate a conditional branch");
+    assert(match(Br->getCondition(), m_Zero()) &&
+           "Expected branch condition to be false");
+    IRBuilder<> Builder(Br);
+    Function *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,
+                                            FI.OuterTripCount->getType());
+    Value *Call = Builder.CreateCall(F, {FI.OuterTripCount, FI.InnerTripCount},
+                                     "flatten.mul");
+    FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount");
+    Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow");
+    Br->setCondition(Overflow);
+  } else {
+    LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
   }
 
-  LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
   return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
 }
 
@@ -958,13 +1005,15 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
   // in simplified form, and also needs LCSSA. Running
   // this pass will simplify all loops that contain inner loops,
   // regardless of whether anything ends up being flattened.
+  LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, nullptr);
   for (Loop *InnerLoop : LN.getLoops()) {
     auto *OuterLoop = InnerLoop->getParentLoop();
     if (!OuterLoop)
       continue;
     FlattenInfo FI(OuterLoop, InnerLoop);
-    Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
-                               MSSAU ? &*MSSAU : nullptr);
+    Changed |=
+        FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
+                        MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop));
   }
 
   if (!Changed)

llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll

Lines changed: 0 additions & 114 deletions
@@ -568,72 +568,6 @@ for.cond.cleanup:
   ret void
 }
 
-; A 3d loop corresponding to:
-;
-;   for (int k = 0; k < N; ++k)
-;     for (int i = 0; i < N; ++i)
-;       for (int j = 0; j < M; ++j)
-;         f(&A[i*M+j]);
-;
-; This could be supported, but isn't at the moment.
-;
-define void @d3_2(i32* %A, i32 %N, i32 %M) {
-entry:
-  %cmp30 = icmp sgt i32 %N, 0
-  br i1 %cmp30, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup
-
-for.cond1.preheader.lr.ph:
-  %cmp625 = icmp sgt i32 %M, 0
-  br label %for.cond1.preheader.us
-
-for.cond1.preheader.us:
-  %k.031.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
-  br i1 %cmp625, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us43.preheader
-
-for.cond5.preheader.us43.preheader:
-  br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50
-
-for.cond5.preheader.us.us.preheader:
-  br label %for.cond5.preheader.us.us
-
-for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
-  br label %for.cond1.for.cond.cleanup3_crit_edge.us
-
-for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50:
-  br label %for.cond1.for.cond.cleanup3_crit_edge.us
-
-for.cond1.for.cond.cleanup3_crit_edge.us:
-  %inc13.us = add nuw nsw i32 %k.031.us, 1
-  %exitcond52 = icmp ne i32 %inc13.us, %N
-  br i1 %exitcond52, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit
-
-for.cond5.preheader.us.us:
-  %i.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
-  %mul.us.us = mul nsw i32 %i.028.us.us, %M
-  br label %for.body8.us.us
-
-for.cond5.for.cond.cleanup7_crit_edge.us.us:
-  %inc10.us.us = add nuw nsw i32 %i.028.us.us, 1
-  %exitcond51 = icmp ne i32 %inc10.us.us, %N
-  br i1 %exitcond51, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit
-
-for.body8.us.us:
-  %j.026.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
-  %add.us.us = add nsw i32 %j.026.us.us, %mul.us.us
-  %idxprom.us.us = sext i32 %add.us.us to i64
-  %arrayidx.us.us = getelementptr inbounds i32, ptr %A, i64 %idxprom.us.us
-  tail call void @f(ptr %arrayidx.us.us) #2
-  %inc.us.us = add nuw nsw i32 %j.026.us.us, 1
-  %exitcond = icmp ne i32 %inc.us.us, %M
-  br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us
-
-for.cond.cleanup.loopexit:
-  br label %for.cond.cleanup
-
-for.cond.cleanup:
-  ret void
-}
-
 ; A 3d loop corresponding to:
 ;
 ;   for (int i = 0; i < N; ++i)
@@ -785,54 +719,6 @@ for.empty:
   ret void
 }
 
-; GEP doesn't dominate the loop latch so can't guarantee N*M won't overflow.
-@first = global i32 1, align 4
-@a = external global [0 x i8], align 1
-define void @overflow(i32 %lim, ptr %a) {
-entry:
-  %cmp17.not = icmp eq i32 %lim, 0
-  br i1 %cmp17.not, label %for.cond.cleanup, label %for.cond1.preheader.preheader
-
-for.cond1.preheader.preheader:
-  br label %for.cond1.preheader
-
-for.cond1.preheader:
-  %i.018 = phi i32 [ %inc6, %for.cond.cleanup3 ], [ 0, %for.cond1.preheader.preheader ]
-  %mul = mul i32 %i.018, 100000
-  br label %for.body4
-
-for.cond.cleanup.loopexit:
-  br label %for.cond.cleanup
-
-for.cond.cleanup:
-  ret void
-
-for.cond.cleanup3:
-  %inc6 = add i32 %i.018, 1
-  %cmp = icmp ult i32 %inc6, %lim
-  br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup.loopexit
-
-for.body4:
-  %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %if.end ]
-  %add = add i32 %j.016, %mul
-  %0 = load i32, ptr @first, align 4
-  %tobool.not = icmp eq i32 %0, 0
-  br i1 %tobool.not, label %if.end, label %if.then
-
-if.then:
-  %arrayidx = getelementptr inbounds [0 x i8], ptr @a, i32 0, i32 %add
-  %1 = load i8, ptr %arrayidx, align 1
-  tail call void asm sideeffect "", "r"(i8 %1)
-  store i32 0, ptr @first, align 4
-  br label %if.end
-
-if.end:
-  tail call void asm sideeffect "", "r"(i32 %add)
-  %inc = add nuw nsw i32 %j.016, 1
-  %cmp2 = icmp ult i32 %j.016, 99999
-  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
-}
-
 declare void @objc_enumerationMutation(ptr)
 declare dso_local void @f(ptr)
 declare dso_local void @g(...)
