Skip to content

Commit 62e1786

Browse files
committed
[LoopVectorize] Enable vectorisation of early exit loops with live-outs
This work feeds part of PR #88385, and adds support for vectorising loops with uncountable early exits and outside users of loop-defined variables. I've added a new fixupEarlyExitIVUsers to mirror what happens in fixupIVUsers when patching up outside users of induction variables in the early exit block. We have to handle these differently for two reasons: 1. We can't work backwards from the end value in the middle block because we didn't leave at the last iteration. 2. We need to generate different IR that calculates the vector lane that triggered the exit, and hence can determine the induction value at the point we exited. I've added a new 'null' VPValue as a dummy placeholder to manage the incoming operands of PHI nodes in the exit block. We can have situations where one of the incoming values is an induction variable (or its update) and the other is not. For example, both the latch and the early exiting block can jump to the same exit block. However, VPInstruction::generate walks through all predecessors of the PHI assuming the value is *not* an IV. In order to ensure that we process the right value for the right incoming block we use this new 'null' value is a marker to indicate it should be skipped, since it will be handled separately in fixupIVUsers or fixupEarlyExitIVUsers. All code for calculating the last value when exiting the loop early now lives in a new vector.early.exit block, which sits between the middle.split block and the original exit block. I also had to fix up the vplan verifier because it assumed that the block containing a definition always dominated the parent of the user. That's no longer the case because we can arrive at the exit block via one of the latch or the early exiting block. I've added a new ExtractFirstActive VPInstruction that extracts the first active lane of a vector, i.e. the lane of the vector predicate that triggered the exit.
1 parent 1feeeb4 commit 62e1786

14 files changed

+1186
-201
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 174 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,11 @@ class InnerLoopVectorizer {
548548
Value *VectorTripCount, BasicBlock *MiddleBlock,
549549
VPTransformState &State);
550550

551+
void fixupEarlyExitIVUsers(PHINode *OrigPhi, const InductionDescriptor &II,
552+
BasicBlock *VectorEarlyExitBB,
553+
BasicBlock *MiddleBlock, VPlan &Plan,
554+
VPTransformState &State);
555+
551556
/// Iteratively sink the scalarized operands of a predicated instruction into
552557
/// the block that was created for it.
553558
void sinkScalarOperands(Instruction *PredInst);
@@ -2775,6 +2780,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
27752780
return LoopVectorPreHeader;
27762781
}
27772782

2783+
static bool isValueIncomingFromBlock(BasicBlock *ExitingBB, Value *V,
2784+
Instruction *UI) {
2785+
PHINode *PHI = dyn_cast<PHINode>(UI);
2786+
assert(PHI && "Expected LCSSA form");
2787+
2788+
// If this loop has an uncountable early exit then there could be
2789+
// different users of OrigPhi with either:
2790+
// 1. Multiple users, because each exiting block (countable or
2791+
// uncountable) jumps to the same exit block, or ..
2792+
// 2. A single user with an incoming value from a countable or
2793+
// uncountable exiting block.
2794+
// In both cases there is no guarantee this came from a countable exiting
2795+
// block, i.e. the latch.
2796+
int Index = PHI->getBasicBlockIndex(ExitingBB);
2797+
return Index != -1 && PHI->getIncomingValue(Index) == V;
2798+
}
2799+
27782800
// Fix up external users of the induction variable. At this point, we are
27792801
// in LCSSA form, with all external PHIs that use the IV having one input value,
27802802
// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2790,19 +2812,20 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
27902812
// We allow both, but they, obviously, have different values.
27912813

27922814
DenseMap<Value *, Value *> MissingVals;
2815+
BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
27932816

27942817
Value *EndValue = cast<PHINode>(OrigPhi->getIncomingValueForBlock(
27952818
OrigLoop->getLoopPreheader()))
27962819
->getIncomingValueForBlock(MiddleBlock);
27972820

27982821
// An external user of the last iteration's value should see the value that
27992822
// the remainder loop uses to initialize its own IV.
2800-
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch());
2823+
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
28012824
for (User *U : PostInc->users()) {
28022825
Instruction *UI = cast<Instruction>(U);
28032826
if (!OrigLoop->contains(UI)) {
2804-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2805-
MissingVals[UI] = EndValue;
2827+
if (isValueIncomingFromBlock(OrigLoopLatch, PostInc, UI))
2828+
MissingVals[cast<PHINode>(UI)] = EndValue;
28062829
}
28072830
}
28082831

@@ -2812,7 +2835,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28122835
for (User *U : OrigPhi->users()) {
28132836
auto *UI = cast<Instruction>(U);
28142837
if (!OrigLoop->contains(UI)) {
2815-
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2838+
if (!isValueIncomingFromBlock(OrigLoopLatch, OrigPhi, UI))
2839+
continue;
2840+
28162841
IRBuilder<> B(MiddleBlock->getTerminator());
28172842

28182843
// Fast-math-flags propagate from the original induction instruction.
@@ -2842,18 +2867,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28422867
}
28432868
}
28442869

2845-
assert((MissingVals.empty() ||
2846-
all_of(MissingVals,
2847-
[MiddleBlock, this](const std::pair<Value *, Value *> &P) {
2848-
return all_of(
2849-
predecessors(cast<Instruction>(P.first)->getParent()),
2850-
[MiddleBlock, this](BasicBlock *Pred) {
2851-
return Pred == MiddleBlock ||
2852-
Pred == OrigLoop->getLoopLatch();
2853-
});
2854-
})) &&
2855-
"Expected escaping values from latch/middle.block only");
2856-
28572870
for (auto &I : MissingVals) {
28582871
PHINode *PHI = cast<PHINode>(I.first);
28592872
// One corner case we have to handle is two IVs "chasing" each-other,
@@ -2866,6 +2879,102 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28662879
}
28672880
}
28682881

2882+
void InnerLoopVectorizer::fixupEarlyExitIVUsers(PHINode *OrigPhi,
2883+
const InductionDescriptor &II,
2884+
BasicBlock *VectorEarlyExitBB,
2885+
BasicBlock *MiddleBlock,
2886+
VPlan &Plan,
2887+
VPTransformState &State) {
2888+
// There are two kinds of external IV usages - those that use the value
2889+
// computed in the last iteration (the PHI) and those that use the penultimate
2890+
// value (the value that feeds into the phi from the loop latch).
2891+
// We allow both, but they, obviously, have different values.
2892+
DenseMap<Value *, Value *> MissingVals;
2893+
BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch();
2894+
BasicBlock *EarlyExitingBB = Legal->getUncountableEarlyExitingBlock();
2895+
Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoopLatch);
2896+
2897+
// Obtain the canonical IV, since we have to use the most recent value
2898+
// before exiting the loop early. This is unlike fixupIVUsers, which has
2899+
// the luxury of using the end value in the middle block.
2900+
VPBasicBlock *EntryVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock();
2901+
// NOTE: We cannot call Plan.getCanonicalIV() here because the original
2902+
// recipe created whilst building plans is no longer valid.
2903+
VPHeaderPHIRecipe *CanonicalIVR =
2904+
cast<VPHeaderPHIRecipe>(&*EntryVPBB->begin());
2905+
Value *CanonicalIV = State.get(CanonicalIVR->getVPSingleValue(), true);
2906+
2907+
// Search for the mask that drove us to exit early.
2908+
VPBasicBlock *EarlyExitVPBB = Plan.getVectorLoopRegion()->getEarlyExit();
2909+
VPBasicBlock *MiddleSplitVPBB =
2910+
cast<VPBasicBlock>(EarlyExitVPBB->getSinglePredecessor());
2911+
VPInstruction *BranchOnCond =
2912+
cast<VPInstruction>(MiddleSplitVPBB->getTerminator());
2913+
assert(BranchOnCond->getOpcode() == VPInstruction::BranchOnCond &&
2914+
"Expected middle.split block terminator to be a branch-on-cond");
2915+
VPInstruction *ScalarEarlyExitCond =
2916+
cast<VPInstruction>(BranchOnCond->getOperand(0));
2917+
assert(
2918+
ScalarEarlyExitCond->getOpcode() == VPInstruction::AnyOf &&
2919+
"Expected middle.split block terminator branch condition to be any-of");
2920+
VPValue *VectorEarlyExitCond = ScalarEarlyExitCond->getOperand(0);
2921+
// Finally get the mask that led us into the early exit block.
2922+
Value *EarlyExitMask = State.get(VectorEarlyExitCond);
2923+
2924+
// Calculate the IV step.
2925+
VPValue *StepVPV = Plan.getSCEVExpansion(II.getStep());
2926+
assert(StepVPV && "step must have been expanded during VPlan execution");
2927+
Value *Step = StepVPV->isLiveIn() ? StepVPV->getLiveInIRValue()
2928+
: State.get(StepVPV, VPLane(0));
2929+
2930+
auto FixUpPhi = [&](Instruction *UI, bool PostInc) -> Value * {
2931+
IRBuilder<> B(VectorEarlyExitBB->getTerminator());
2932+
assert(isa<PHINode>(UI) && "Expected LCSSA form");
2933+
2934+
// Fast-math-flags propagate from the original induction instruction.
2935+
if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp()))
2936+
B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags());
2937+
2938+
Type *CtzType = CanonicalIV->getType();
2939+
Value *Ctz = B.CreateCountTrailingZeroElems(CtzType, EarlyExitMask);
2940+
Ctz = B.CreateAdd(Ctz, cast<PHINode>(CanonicalIV));
2941+
if (PostInc)
2942+
Ctz = B.CreateAdd(Ctz, ConstantInt::get(CtzType, 1));
2943+
2944+
Value *Escape = emitTransformedIndex(B, Ctz, II.getStartValue(), Step,
2945+
II.getKind(), II.getInductionBinOp());
2946+
Escape->setName("ind.early.escape");
2947+
return Escape;
2948+
};
2949+
2950+
for (User *U : PostInc->users()) {
2951+
auto *UI = cast<Instruction>(U);
2952+
if (!OrigLoop->contains(UI)) {
2953+
if (isValueIncomingFromBlock(EarlyExitingBB, PostInc, UI))
2954+
MissingVals[UI] = FixUpPhi(UI, true);
2955+
}
2956+
}
2957+
2958+
for (User *U : OrigPhi->users()) {
2959+
auto *UI = cast<Instruction>(U);
2960+
if (!OrigLoop->contains(UI)) {
2961+
if (isValueIncomingFromBlock(EarlyExitingBB, OrigPhi, UI))
2962+
MissingVals[UI] = FixUpPhi(UI, false);
2963+
}
2964+
}
2965+
2966+
for (auto &I : MissingVals) {
2967+
PHINode *PHI = cast<PHINode>(I.first);
2968+
// One corner case we have to handle is two IVs "chasing" each-other,
2969+
// that is %IV2 = phi [...], [ %IV1, %latch ]
2970+
// In this case, if IV1 has an external use, we need to avoid adding both
2971+
// "last value of IV1" and "penultimate value of IV2". So, verify that we
2972+
// don't already have an incoming value for the middle block.
2973+
if (PHI->getBasicBlockIndex(VectorEarlyExitBB) == -1)
2974+
PHI->addIncoming(I.second, VectorEarlyExitBB);
2975+
}
2976+
}
2977+
28692978
namespace {
28702979

28712980
struct CSEDenseMapInfo {
@@ -2985,6 +3094,20 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
29853094
PSE.getSE()->forgetLoop(OrigLoop);
29863095
PSE.getSE()->forgetBlockAndLoopDispositions();
29873096

3097+
// When dealing with uncountable early exits we create middle.split blocks
3098+
// between the vector loop region and the exit block. These blocks need
3099+
// adding to any outer loop.
3100+
VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion();
3101+
Loop *OuterLoop = OrigLoop->getParentLoop();
3102+
if (Legal->hasUncountableEarlyExit() && OuterLoop) {
3103+
BasicBlock *OrigEarlyExitBB = Legal->getUncountableEarlyExitBlock();
3104+
if (Loop *EEL = LI->getLoopFor(OrigEarlyExitBB)) {
3105+
BasicBlock *VectorEarlyExitBB =
3106+
State.CFG.VPBB2IRBB[VectorRegion->getEarlyExit()];
3107+
EEL->addBasicBlockToLoop(VectorEarlyExitBB, *LI);
3108+
}
3109+
}
3110+
29883111
// After vectorization, the exit blocks of the original loop will have
29893112
// additional predecessors. Invalidate SCEVs for the exit phis in case SE
29903113
// looked through single-entry phis.
@@ -3012,15 +3135,23 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30123135
getOrCreateVectorTripCount(nullptr), LoopMiddleBlock, State);
30133136
}
30143137

3138+
if (Legal->hasUncountableEarlyExit()) {
3139+
VPBasicBlock *VectorEarlyExitVPBB =
3140+
cast<VPBasicBlock>(VectorRegion->getEarlyExit());
3141+
BasicBlock *VectorEarlyExitBB = State.CFG.VPBB2IRBB[VectorEarlyExitVPBB];
3142+
for (const auto &Entry : Legal->getInductionVars())
3143+
fixupEarlyExitIVUsers(Entry.first, Entry.second, VectorEarlyExitBB,
3144+
LoopMiddleBlock, Plan, State);
3145+
}
3146+
30153147
// Don't apply optimizations below when no vector region remains, as they all
30163148
// require a vector loop at the moment.
3017-
if (!State.Plan->getVectorLoopRegion())
3149+
if (!VectorRegion)
30183150
return;
30193151

30203152
for (Instruction *PI : PredicatedInstructions)
30213153
sinkScalarOperands(&*PI);
30223154

3023-
VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion();
30243155
VPBasicBlock *HeaderVPBB = VectorRegion->getEntryBasicBlock();
30253156
BasicBlock *HeaderBB = State.CFG.VPBB2IRBB[HeaderVPBB];
30263157

@@ -8948,6 +9079,10 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
89489079
continue;
89499080
}
89509081

9082+
assert(!Plan.getVectorLoopRegion()->getEarlyExit() &&
9083+
"Cannot handle "
9084+
"first-order recurrences with uncountable early exits");
9085+
89519086
// The backedge value provides the value to resume coming out of a loop,
89529087
// which for FORs is a vector whose last element needs to be extracted. The
89539088
// start value provides the value if the loop is bypassed.
@@ -9056,9 +9191,8 @@ collectUsersInExitBlocks(Loop *OrigLoop, VPRecipeBuilder &Builder,
90569191
// Exit values for inductions are computed and updated outside of VPlan
90579192
// and independent of induction recipes.
90589193
// TODO: Compute induction exit values in VPlan.
9059-
if (isOptimizableIVOrUse(V) &&
9060-
ExitVPBB->getSinglePredecessor() == MiddleVPBB)
9061-
continue;
9194+
if (isOptimizableIVOrUse(V))
9195+
V = VPValue::getNull();
90629196
ExitUsersToFix.insert(ExitIRI);
90639197
ExitIRI->addOperand(V);
90649198
}
@@ -9085,18 +9219,30 @@ addUsersInExitBlocks(VPlan &Plan,
90859219
for (const auto &[Idx, Op] : enumerate(ExitIRI->operands())) {
90869220
// Pass live-in values used by exit phis directly through to their users
90879221
// in the exit block.
9088-
if (Op->isLiveIn())
9222+
if (Op->isLiveIn() || Op->isNull())
90899223
continue;
90909224

90919225
// Currently only live-ins can be used by exit values from blocks not
90929226
// exiting via the vector latch through to the middle block.
9093-
if (ExitIRI->getParent()->getSinglePredecessor() != MiddleVPBB)
9094-
return false;
9095-
90969227
LLVMContext &Ctx = ExitIRI->getInstruction().getContext();
9097-
VPValue *Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9098-
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9099-
IntegerType::get(Ctx, 32), 1))});
9228+
VPValue *Ext;
9229+
VPBasicBlock *PredVPBB =
9230+
cast<VPBasicBlock>(ExitIRI->getParent()->getPredecessors()[Idx]);
9231+
if (PredVPBB != MiddleVPBB) {
9232+
VPBasicBlock *VectorEarlyExitVPBB =
9233+
Plan.getVectorLoopRegion()->getEarlyExit();
9234+
VPBuilder B2(VectorEarlyExitVPBB,
9235+
VectorEarlyExitVPBB->getFirstNonPhi());
9236+
assert(ExitIRI->getParent()->getNumPredecessors() <= 2);
9237+
VPValue *EarlyExitMask =
9238+
Plan.getVectorLoopRegion()->getVectorEarlyExitCond();
9239+
Ext = B2.createNaryOp(VPInstruction::ExtractFirstActive,
9240+
{Op, EarlyExitMask});
9241+
} else {
9242+
Ext = B.createNaryOp(VPInstruction::ExtractFromEnd,
9243+
{Op, Plan.getOrAddLiveIn(ConstantInt::get(
9244+
IntegerType::get(Ctx, 32), 1))});
9245+
}
91009246
ExitIRI->setOperand(Idx, Ext);
91019247
}
91029248
}

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ Value *VPLane::getAsRuntimeExpr(IRBuilderBase &Builder,
8383
llvm_unreachable("Unknown lane kind");
8484
}
8585

86+
static VPValue NullValue;
87+
VPValue *VPValue::Null = &NullValue;
88+
8689
VPValue::VPValue(const unsigned char SC, Value *UV, VPDef *Def)
8790
: SubclassID(SC), UnderlyingVal(UV), Def(Def) {
8891
if (Def)

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,9 @@ class VPInstruction : public VPRecipeWithIRFlags,
12251225
// Returns a scalar boolean value, which is true if any lane of its single
12261226
// operand is true.
12271227
AnyOf,
1228+
// Extracts the first active lane of a vector, where the first operand is
1229+
// the predicate, and the second operand is the vector to extract.
1230+
ExtractFirstActive,
12281231
};
12291232

12301233
private:
@@ -3651,6 +3654,13 @@ class VPRegionBlock : public VPBlockBase {
36513654
/// VPRegionBlock.
36523655
VPBlockBase *Exiting;
36533656

3657+
/// Hold the Early Exit block of the SEME region, if one exists.
3658+
VPBasicBlock *EarlyExit;
3659+
3660+
/// If one exists, this keeps track of the vector early mask that triggered
3661+
/// the early exit.
3662+
VPValue *VectorEarlyExitCond;
3663+
36543664
/// An indicator whether this region is to generate multiple replicated
36553665
/// instances of output IR corresponding to its VPBlockBases.
36563666
bool IsReplicator;
@@ -3659,6 +3669,7 @@ class VPRegionBlock : public VPBlockBase {
36593669
VPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting,
36603670
const std::string &Name = "", bool IsReplicator = false)
36613671
: VPBlockBase(VPRegionBlockSC, Name), Entry(Entry), Exiting(Exiting),
3672+
EarlyExit(nullptr), VectorEarlyExitCond(nullptr),
36623673
IsReplicator(IsReplicator) {
36633674
assert(Entry->getPredecessors().empty() && "Entry block has predecessors.");
36643675
assert(Exiting->getSuccessors().empty() && "Exit block has successors.");
@@ -3667,6 +3678,7 @@ class VPRegionBlock : public VPBlockBase {
36673678
}
36683679
VPRegionBlock(const std::string &Name = "", bool IsReplicator = false)
36693680
: VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exiting(nullptr),
3681+
EarlyExit(nullptr), VectorEarlyExitCond(nullptr),
36703682
IsReplicator(IsReplicator) {}
36713683

36723684
~VPRegionBlock() override {}
@@ -3700,6 +3712,22 @@ class VPRegionBlock : public VPBlockBase {
37003712
ExitingBlock->setParent(this);
37013713
}
37023714

3715+
/// Sets the early exit vector mask.
3716+
void setVectorEarlyExitCond(VPValue *V) {
3717+
assert(!VectorEarlyExitCond);
3718+
VectorEarlyExitCond = V;
3719+
}
3720+
3721+
/// Gets the early exit vector mask
3722+
VPValue *getVectorEarlyExitCond() const { return VectorEarlyExitCond; }
3723+
3724+
/// Set the vector early exit block
3725+
void setEarlyExit(VPBasicBlock *ExitBlock) { EarlyExit = ExitBlock; }
3726+
3727+
/// Get the vector early exit block
3728+
const VPBasicBlock *getEarlyExit() const { return EarlyExit; }
3729+
VPBasicBlock *getEarlyExit() { return EarlyExit; }
3730+
37033731
/// Returns the pre-header VPBasicBlock of the loop region.
37043732
VPBasicBlock *getPreheaderVPBB() {
37053733
assert(!isReplicator() && "should only get pre-header of loop regions");

0 commit comments

Comments
 (0)