Skip to content

[LoopFlatten] Use loop versioning when overflow can't be disproven #78576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 62 additions & 13 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/SimplifyIndVar.h"
#include <optional>
Expand Down Expand Up @@ -97,6 +98,10 @@ static cl::opt<bool>
cl::desc("Widen the loop induction variables, if possible, so "
"overflow checks won't reject flattening"));

static cl::opt<bool>
VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true),
cl::desc("Version loops if flattened loop could overflow"));

namespace {
// We require all uses of both induction variables to match this pattern:
//
Expand Down Expand Up @@ -141,6 +146,8 @@ struct FlattenInfo {
// has been applied. Used to skip
// checks on phi nodes.

Value *NewTripCount = nullptr; // The tripcount of the flattened loop.

FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};

bool isNarrowInductionPhi(PHINode *Phi) {
Expand Down Expand Up @@ -752,11 +759,13 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
ORE.emit(Remark);
}

Value *NewTripCount = BinaryOperator::CreateMul(
FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
FI.OuterLoop->getLoopPreheader()->getTerminator());
LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
NewTripCount->dump());
if (!FI.NewTripCount) {
FI.NewTripCount = BinaryOperator::CreateMul(
FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
FI.OuterLoop->getLoopPreheader()->getTerminator());
LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
FI.NewTripCount->dump());
}

// Fix up PHI nodes that take values from the inner loop back-edge, which
// we are about to remove.
Expand All @@ -769,7 +778,7 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,

// Modify the trip count of the outer loop to be the product of the two
// trip counts.
cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);
cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount);

// Replace the inner loop backedge with an unconditional branch to the exit.
BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
Expand Down Expand Up @@ -891,7 +900,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
ScalarEvolution *SE, AssumptionCache *AC,
const TargetTransformInfo *TTI, LPMUpdater *U,
MemorySSAUpdater *MSSAU) {
MemorySSAUpdater *MSSAU,
const LoopAccessInfo &LAI) {
LLVM_DEBUG(
dbgs() << "Loop flattening running on outer loop "
<< FI.OuterLoop->getHeader()->getName() << " and inner loop "
Expand Down Expand Up @@ -926,18 +936,55 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
// variable might overflow. In this case, we need to version the loop, and
// select the original version at runtime if the iteration space is too
// large.
// TODO: We currently don't version the loop.
OverflowResult OR = checkOverflow(FI, DT, AC);
if (OR == OverflowResult::AlwaysOverflowsHigh ||
OR == OverflowResult::AlwaysOverflowsLow) {
LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
return false;
} else if (OR == OverflowResult::MayOverflow) {
LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
return false;
Module *M = FI.OuterLoop->getHeader()->getParent()->getParent();
const DataLayout &DL = M->getDataLayout();
if (!VersionLoops) {
LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n");
return false;
} else if (!DL.isLegalInteger(
FI.OuterTripCount->getType()->getScalarSizeInBits())) {
// If the trip count type isn't legal then it won't be possible to check
// for overflow using only a single multiply instruction, so don't
// flatten.
LLVM_DEBUG(
dbgs() << "Can't check overflow efficiently, not flattening\n");
return false;
}
LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n");

// Version the loop. The overflow check isn't a runtime pointer check, so we
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an assertion in LoopVersioning::versionLoop to make sure that some runtime checks are emitted. There is also some code in there which generates checks using SCEV, I'm guessing that's what allows this to work? If so, could you expand this comment to explain why that is guaranteed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've expanded on the comment and also added an assert checking that the branch condition is false as expected.

// pass an empty list of runtime pointer checks, causing LoopVersioning to
// emit 'false' as the branch condition, and add our own check afterwards.
BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader();
ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr);
LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE);
LVer.versionLoop();

// Check for overflow by calculating the new tripcount using
// umul_with_overflow and then checking if it overflowed.
BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator());
assert(Br->isConditional() &&
"Expected LoopVersioning to generate a conditional branch");
assert(match(Br->getCondition(), m_Zero()) &&
"Expected branch condition to be false");
IRBuilder<> Builder(Br);
Function *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,
FI.OuterTripCount->getType());
Value *Call = Builder.CreateCall(F, {FI.OuterTripCount, FI.InnerTripCount},
"flatten.mul");
FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount");
Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow");
Br->setCondition(Overflow);
} else {
LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
}

LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
}

Expand All @@ -958,13 +1005,15 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
// in simplified form, and also needs LCSSA. Running
// this pass will simplify all loops that contain inner loops,
// regardless of whether anything ends up being flattened.
LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, nullptr);
for (Loop *InnerLoop : LN.getLoops()) {
auto *OuterLoop = InnerLoop->getParentLoop();
if (!OuterLoop)
continue;
FlattenInfo FI(OuterLoop, InnerLoop);
Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
MSSAU ? &*MSSAU : nullptr);
Changed |=
FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop));
}

if (!Changed)
Expand Down
114 changes: 0 additions & 114 deletions llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll
Original file line number Diff line number Diff line change
Expand Up @@ -568,72 +568,6 @@ for.cond.cleanup:
ret void
}

; A 3d loop corresponding to:
;
; for (int k = 0; k < N; ++k)
; for (int i = 0; i < N; ++i)
; for (int j = 0; j < M; ++j)
; f(&A[i*M+j]);
;
; This could be supported, but isn't at the moment.
;
define void @d3_2(i32* %A, i32 %N, i32 %M) {
entry:
; Guard: skip everything when the outermost trip count N is non-positive.
%cmp30 = icmp sgt i32 %N, 0
br i1 %cmp30, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup

for.cond1.preheader.lr.ph:
; Hoisted inner-loop guard: M > 0 decides whether the j-loop runs at all.
%cmp625 = icmp sgt i32 %M, 0
br label %for.cond1.preheader.us

for.cond1.preheader.us:
; Outermost loop: k = 0 .. N-1.
%k.031.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ]
br i1 %cmp625, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us43.preheader

for.cond5.preheader.us43.preheader:
; M <= 0 path: the i/j nest is empty for this k.
br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50

for.cond5.preheader.us.us.preheader:
br label %for.cond5.preheader.us.us

for.cond1.for.cond.cleanup3_crit_edge.us.loopexit:
br label %for.cond1.for.cond.cleanup3_crit_edge.us

for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50:
br label %for.cond1.for.cond.cleanup3_crit_edge.us

for.cond1.for.cond.cleanup3_crit_edge.us:
; Outermost latch: advance k and loop while k != N.
%inc13.us = add nuw nsw i32 %k.031.us, 1
%exitcond52 = icmp ne i32 %inc13.us, %N
br i1 %exitcond52, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit

for.cond5.preheader.us.us:
; Middle loop: i = 0 .. N-1; precompute the row base i*M for the j-loop.
%i.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ]
%mul.us.us = mul nsw i32 %i.028.us.us, %M
br label %for.body8.us.us

for.cond5.for.cond.cleanup7_crit_edge.us.us:
; Middle latch: advance i and loop while i != N.
%inc10.us.us = add nuw nsw i32 %i.028.us.us, 1
%exitcond51 = icmp ne i32 %inc10.us.us, %N
br i1 %exitcond51, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit

for.body8.us.us:
; Innermost loop: j = 0 .. M-1; linearized index i*M+j is passed to f.
%j.026.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ]
%add.us.us = add nsw i32 %j.026.us.us, %mul.us.us
%idxprom.us.us = sext i32 %add.us.us to i64
%arrayidx.us.us = getelementptr inbounds i32, ptr %A, i64 %idxprom.us.us
tail call void @f(ptr %arrayidx.us.us) #2
%inc.us.us = add nuw nsw i32 %j.026.us.us, 1
%exitcond = icmp ne i32 %inc.us.us, %M
br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us

for.cond.cleanup.loopexit:
br label %for.cond.cleanup

for.cond.cleanup:
ret void
}

; A 3d loop corresponding to:
;
; for (int i = 0; i < N; ++i)
Expand Down Expand Up @@ -785,54 +719,6 @@ for.empty:
ret void
}

; GEP doesn't dominate the loop latch so can't guarantee N*M won't overflow.
@first = global i32 1, align 4
@a = external global [0 x i8], align 1
define void @overflow(i32 %lim, ptr %a) {
entry:
; Guard: no iterations when the outer bound lim is zero.
%cmp17.not = icmp eq i32 %lim, 0
br i1 %cmp17.not, label %for.cond.cleanup, label %for.cond1.preheader.preheader

for.cond1.preheader.preheader:
br label %for.cond1.preheader

for.cond1.preheader:
; Outer loop: i = 0 .. lim-1; row base is i*100000 (no nuw/nsw, may wrap).
%i.018 = phi i32 [ %inc6, %for.cond.cleanup3 ], [ 0, %for.cond1.preheader.preheader ]
%mul = mul i32 %i.018, 100000
br label %for.body4

for.cond.cleanup.loopexit:
br label %for.cond.cleanup

for.cond.cleanup:
ret void

for.cond.cleanup3:
; Outer latch: advance i and loop while i < lim (unsigned).
%inc6 = add i32 %i.018, 1
%cmp = icmp ult i32 %inc6, %lim
br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup.loopexit

for.body4:
; Inner loop: j = 0 .. 99999; linearized index j + i*100000.
%j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %if.end ]
%add = add i32 %j.016, %mul
%0 = load i32, ptr @first, align 4
%tobool.not = icmp eq i32 %0, 0
br i1 %tobool.not, label %if.end, label %if.then

if.then:
; First-iteration-only path: the GEP into @a sits under this condition,
; so it does not dominate the inner latch and cannot prove the index
; (hence lim*100000) stays in bounds / does not overflow.
%arrayidx = getelementptr inbounds [0 x i8], ptr @a, i32 0, i32 %add
%1 = load i8, ptr %arrayidx, align 1
tail call void asm sideeffect "", "r"(i8 %1)
store i32 0, ptr @first, align 4
br label %if.end

if.end:
; Inner latch: keep %add live via inline asm; loop while j <= 99999.
tail call void asm sideeffect "", "r"(i32 %add)
%inc = add nuw nsw i32 %j.016, 1
%cmp2 = icmp ult i32 %j.016, 99999
br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
}

declare void @objc_enumerationMutation(ptr)
declare dso_local void @f(ptr)
declare dso_local void @g(...)
Loading