Skip to content

Commit e64012e

Browse files
committed
[LoopVectorize] Vectorize the reduction pattern of integer min/max with index. (2/2)
1 parent 9f4588e commit e64012e

File tree

12 files changed

+2170
-170
lines changed

12 files changed

+2170
-170
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ class RecurrenceDescriptor {
306306
/// Returns the sentinel value for FindLastIV recurrences to replace the start
307307
/// value.
308308
Value *getSentinelValue() const {
309-
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
309+
assert(
310+
(isFindLastIVRecurrenceKind(Kind) || isMinMaxIdxRecurrenceKind(Kind)) &&
311+
"Unexpected recurrence kind");
310312
Type *Ty = StartValue->getType();
311313
return ConstantInt::get(Ty,
312314
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src, Value *InitVal,
423423
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
424424
Value *Sentinel);
425425

426+
/// Create a reduction of the given vector \p Src for a reduction of the
427+
/// kind RecurKind::MinMaxFirstIdx or RecurKind::MinMaxLastIdx. The reduction
428+
/// operation is described by \p Desc.
429+
Value *createMinMaxIdxReduction(IRBuilderBase &B, Value *Src, Value *Start,
430+
const RecurrenceDescriptor &Desc);
431+
426432
/// Create an ordered reduction intrinsic using the given recurrence
427433
/// kind \p RdxKind.
428434
Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ class LoopVectorizationLegality {
307307
/// Return the fixed-order recurrences found in the loop.
308308
RecurrenceSet &getFixedOrderRecurrences() { return FixedOrderRecurrences; }
309309

310+
/// Return the min/max recurrences found in the loop.
311+
const SmallDenseMap<PHINode *, PHINode *> &getMinMaxRecurrences() {
312+
return MinMaxRecurrences;
313+
}
314+
310315
/// Returns the widest induction type.
311316
IntegerType *getWidestInductionType() { return WidestIndTy; }
312317

@@ -618,7 +623,7 @@ class LoopVectorizationLegality {
618623
RecurrenceSet FixedOrderRecurrences;
619624

620625
/// Holds the min/max recurrence variables.
621-
RecurrenceSet MinMaxRecurrences;
626+
SmallDenseMap<PHINode *, PHINode *> MinMaxRecurrences;
622627

623628
/// Holds the widest induction type encountered.
624629
IntegerType *WidestIndTy = nullptr;

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,25 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12481248
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
12491249
}
12501250

1251+
Value *llvm::createMinMaxIdxReduction(IRBuilderBase &Builder, Value *Src,
1252+
Value *Start,
1253+
const RecurrenceDescriptor &Desc) {
1254+
RecurKind Kind = Desc.getRecurrenceKind();
1255+
assert(RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1256+
"Unexpected reduction kind");
1257+
Value *Sentinel = Desc.getSentinelValue();
1258+
Value *Rdx = Src;
1259+
if (Src->getType()->isVectorTy())
1260+
Rdx = Kind == RecurKind::MinMaxFirstIdx
1261+
? Builder.CreateIntMinReduce(Src, true)
1262+
: Builder.CreateIntMaxReduce(Src, true);
1263+
// Correct the final reduction result back to the start value if the reduction
1264+
// result is sentinel value.
1265+
Value *Cmp =
1266+
Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Sentinel, "rdx.select.cmp");
1267+
return Builder.CreateSelect(Cmp, Rdx, Start, "rdx.select");
1268+
}
1269+
12511270
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
12521271
FastMathFlags Flags) {
12531272
bool Negative = false;
@@ -1336,7 +1355,8 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
13361355
RecurKind Kind) {
13371356
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
13381357
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
1339-
"AnyOf or FindLastIV reductions are not supported.");
1358+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1359+
"AnyOf, FindLastIV and MinMaxIdx reductions are not supported.");
13401360
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13411361
auto *SrcTy = cast<VectorType>(Src->getType());
13421362
Type *SrcEltTy = SrcTy->getElementType();

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
851851
if (MinMaxRecurDes.getLoopExitInstr())
852852
AllowedExit.insert(MinMaxRecurDes.getLoopExitInstr());
853853
Reductions[Phi] = MinMaxRecurDes;
854-
MinMaxRecurrences.insert(Phi);
854+
MinMaxRecurrences.try_emplace(Phi);
855855
MinMaxRecurrenceChains[Phi] = std::move(Chain);
856856
continue;
857857
}
@@ -1093,10 +1093,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
10931093
if (!canVectorizeMinMaxRecurrence(Phi, Chain))
10941094
return false;
10951095
}
1096-
// FIXME: Remove this after the IR generation of min/max with index is
1097-
// supported.
1098-
if (!MinMaxRecurrences.empty())
1099-
return false;
11001096

11011097
return true;
11021098
}
@@ -1106,6 +1102,10 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11061102
assert(!Chain.empty() && "Unexpected empty recurrence chain");
11071103
assert(isMinMaxRecurrence(Phi) && "The PHI is not a min/max recurrence phi");
11081104

1105+
auto It = MinMaxRecurrences.find(Phi);
1106+
if (It->second)
1107+
return true;
1108+
11091109
auto IsMinMaxIdxReductionPhi = [this, Phi, &Chain](Value *Candidate) -> bool {
11101110
auto *IdxPhi = dyn_cast<PHINode>(Candidate);
11111111
if (!IdxPhi || !isReductionVariable(IdxPhi))
@@ -1150,7 +1150,17 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11501150

11511151
auto *TrueVal = IdxChainHead->getTrueValue();
11521152
auto *FalseVal = IdxChainHead->getFalseValue();
1153-
return IsMinMaxIdxReductionPhi(TrueVal) || IsMinMaxIdxReductionPhi(FalseVal);
1153+
PHINode *IdxPhi;
1154+
if (IsMinMaxIdxReductionPhi(TrueVal))
1155+
IdxPhi = cast<PHINode>(TrueVal);
1156+
else if (IsMinMaxIdxReductionPhi(FalseVal))
1157+
IdxPhi = cast<PHINode>(FalseVal);
1158+
else
1159+
return false;
1160+
1161+
// Record the index reduction phi that uses the min/max recurrence.
1162+
It->second = IdxPhi;
1163+
return true;
11541164
}
11551165

11561166
/// Find histogram operations that match high-level code in loops:
@@ -1973,7 +1983,8 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
19731983
SmallPtrSet<const Value *, 8> ReductionLiveOuts;
19741984

19751985
for (const auto &Reduction : getReductionVars())
1976-
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
1986+
if (auto *ExitInstr = Reduction.second.getLoopExitInstr())
1987+
ReductionLiveOuts.insert(ExitInstr);
19771988

19781989
// TODO: handle non-reduction outside users when tail is folded by masking.
19791990
for (auto *AE : AllowedExit) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4457,6 +4457,14 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
44574457
return false;
44584458
}
44594459

4460+
// TODO: support epilogue vectorization for min/max with index.
4461+
if (any_of(Legal->getReductionVars(), [](const auto &Reduction) {
4462+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
4463+
return RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(
4464+
RdxDesc.getRecurrenceKind());
4465+
}))
4466+
return false;
4467+
44604468
// Epilogue vectorization code has not been audited to ensure it handles
44614469
// non-latch exits properly. It may be fine, but it needs to be audited and
44624470
// tested.
@@ -4901,7 +4909,8 @@ LoopVectorizationCostModel::selectInterleaveCount(VPlan &Plan, ElementCount VF,
49014909
const RecurrenceDescriptor &RdxDesc = Reduction.second;
49024910
RecurKind RK = RdxDesc.getRecurrenceKind();
49034911
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
4904-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
4912+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) ||
4913+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK);
49054914
});
49064915
if (HasSelectCmpReductions) {
49074916
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -6618,6 +6627,10 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
66186627

66196628
for (const auto &Reduction : Legal->getReductionVars()) {
66206629
PHINode *Phi = Reduction.first;
6630+
// TODO: support in-loop min/max with index.
6631+
if (Legal->isMinMaxRecurrence(Phi))
6632+
continue;
6633+
66216634
const RecurrenceDescriptor &RdxDesc = Reduction.second;
66226635

66236636
// We don't collect reductions that are type promoted (yet).
@@ -7231,6 +7244,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72317244
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
72327245
return;
72337246

7247+
assert(EpiRedResult->getOpcode() != VPInstruction::ComputeMinMaxIdxResult);
7248+
72347249
auto *EpiRedHeaderPhi =
72357250
cast<VPReductionPHIRecipe>(EpiRedResult->getOperand(0));
72367251
const RecurrenceDescriptor &RdxDesc =
@@ -8143,10 +8158,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
81438158
// Find all possible partial reductions.
81448159
SmallVector<std::pair<PartialReductionChain, unsigned>>
81458160
PartialReductionChains;
8146-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8147-
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8148-
PartialReductionChains);
8149-
}
8161+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8162+
if (auto *ExitInstr = RdxDesc.getLoopExitInstr())
8163+
getScaledReductions(Phi, ExitInstr, Range, PartialReductionChains);
81508164

81518165
// A partial reduction is invalid if any of its extends are used by
81528166
// something that isn't another partial reduction. This is because the
@@ -9040,6 +9054,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
90409054
assert(
90419055
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
90429056
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9057+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
90439058
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
90449059

90459060
// Collect the chain of "link" recipes for the reduction starting at PhiR.
@@ -9163,15 +9178,32 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91639178
PreviousLink = RedRecipe;
91649179
}
91659180
}
9181+
9182+
// Collect all VPReductionPHIRecipes in the header block, and sort them based
9183+
// on the dependency order of the reductions. This ensures that results of
9184+
// min/max reductions are computed before their corresponding index
9185+
// reductions, since the index reduction relies on the result of the min/max
9186+
// reduction to determine which lane produced the min/max.
9187+
SmallVector<VPReductionPHIRecipe *> VPReductionPHIs;
9188+
for (VPRecipeBase &R : Header->phis())
9189+
if (auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R))
9190+
VPReductionPHIs.push_back(PhiR);
9191+
9192+
stable_sort(VPReductionPHIs, [this](const VPReductionPHIRecipe *R1,
9193+
const VPReductionPHIRecipe *R2) {
9194+
auto *Phi1 = cast<PHINode>(R1->getUnderlyingInstr());
9195+
if (!Legal->isMinMaxRecurrence(Phi1))
9196+
return false;
9197+
9198+
auto *Phi2 = cast<PHINode>(R2->getUnderlyingInstr());
9199+
return Legal->getMinMaxRecurrences().find(Phi1)->second == Phi2;
9200+
});
9201+
91669202
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
91679203
Builder.setInsertPoint(&*std::prev(std::prev(LatchVPBB->end())));
91689204
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
9169-
for (VPRecipeBase &R :
9170-
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
9171-
VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
9172-
if (!PhiR)
9173-
continue;
9174-
9205+
SmallDenseMap<VPReductionPHIRecipe *, VPValue *> IdxReductionMasks;
9206+
for (auto *PhiR : VPReductionPHIs) {
91759207
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
91769208
Type *PhiTy = PhiR->getOperand(0)->getLiveInIRValue()->getType();
91779209
// If tail is folded by masking, introduce selects between the phi
@@ -9198,7 +9230,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91989230
cast<VPInstruction>(&U)->getOpcode() ==
91999231
VPInstruction::ComputeReductionResult ||
92009232
cast<VPInstruction>(&U)->getOpcode() ==
9201-
VPInstruction::ComputeFindLastIVResult);
9233+
VPInstruction::ComputeFindLastIVResult ||
9234+
cast<VPInstruction>(&U)->getOpcode() ==
9235+
VPInstruction::ComputeMinMaxIdxResult);
92029236
});
92039237
if (CM.usePredicatedReductionSelect())
92049238
PhiR->setOperand(1, NewExitingVPV);
@@ -9242,8 +9276,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92429276
VPInstruction *FinalReductionResult;
92439277
VPBuilder::InsertPointGuard Guard(Builder);
92449278
Builder.setInsertPoint(MiddleVPBB, IP);
9245-
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9246-
RdxDesc.getRecurrenceKind())) {
9279+
RecurKind Kind = RdxDesc.getRecurrenceKind();
9280+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind)) {
92479281
VPValue *Start = PhiR->getStartValue();
92489282
FinalReductionResult =
92499283
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
@@ -9254,6 +9288,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92549288
FinalReductionResult =
92559289
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
92569290
{PhiR, Start, NewExitingVPV}, ExitDL);
9291+
} else if (RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind)) {
9292+
// Mask out lanes that cannot be the index of the min/max value.
9293+
VPValue *Mask = IdxReductionMasks.at(PhiR);
9294+
Value *Iden = llvm::getRecurrenceIdentity(
9295+
Kind == RecurKind::MinMaxFirstIdx ? RecurKind::SMin : RecurKind::SMax,
9296+
PhiTy, RdxDesc.getFastMathFlags());
9297+
NewExitingVPV = Builder.createSelect(Mask, NewExitingVPV,
9298+
Plan->getOrAddLiveIn(Iden), ExitDL);
9299+
9300+
VPValue *Start = PhiR->getStartValue();
9301+
FinalReductionResult =
9302+
Builder.createNaryOp(VPInstruction::ComputeMinMaxIdxResult,
9303+
{PhiR, Start, NewExitingVPV}, ExitDL);
92579304
} else {
92589305
VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
92599306
RdxDesc.getRecurrenceKind())
@@ -9265,11 +9312,25 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92659312
}
92669313
// Update all users outside the vector region.
92679314
OrigExitingVPV->replaceUsesWithIf(
9268-
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
9315+
FinalReductionResult,
9316+
[FinalReductionResult, NewExitingVPV](VPUser &User, unsigned) {
92699317
auto *Parent = cast<VPRecipeBase>(&User)->getParent();
9270-
return FinalReductionResult != &User && !Parent->getParent();
9318+
return FinalReductionResult != &User &&
9319+
NewExitingVPV->getDefiningRecipe() != &User &&
9320+
!Parent->getParent();
92719321
});
92729322

9323+
// Generate a mask for the index reduction.
9324+
auto *Phi = cast<PHINode>(PhiR->getUnderlyingInstr());
9325+
if (Legal->isMinMaxRecurrence(Phi)) {
9326+
VPValue *IdxRdxMask = Builder.createICmp(CmpInst::ICMP_EQ, NewExitingVPV,
9327+
FinalReductionResult, ExitDL);
9328+
PHINode *IdxPhi = Legal->getMinMaxRecurrences().find(Phi)->second;
9329+
IdxReductionMasks.try_emplace(
9330+
cast<VPReductionPHIRecipe>(RecipeBuilder.getRecipe(IdxPhi)),
9331+
IdxRdxMask);
9332+
}
9333+
92739334
// Adjust AnyOf reductions; replace the reduction phi for the selected value
92749335
// with a boolean reduction phi node to check if the condition is true in
92759336
// any iteration. The final value is selected by the final
@@ -9304,11 +9365,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93049365
continue;
93059366
}
93069367

9307-
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9308-
RdxDesc.getRecurrenceKind())) {
9309-
// Adjust the start value for FindLastIV recurrences to use the sentinel
9310-
// value after generating the ResumePhi recipe, which uses the original
9311-
// start value.
9368+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) ||
9369+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind)) {
9370+
// Adjust the start value for FindLastIV/MinMaxIdx recurrences to use the
9371+
// sentinel value after generating the ResumePhi recipe, which uses the
9372+
// original start value.
93129373
PhiR->setOperand(0, Plan->getOrAddLiveIn(RdxDesc.getSentinelValue()));
93139374
}
93149375
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
909909
Broadcast,
910910
ComputeAnyOfResult,
911911
ComputeFindLastIVResult,
912+
ComputeMinMaxIdxResult,
912913
ComputeReductionResult,
913914
// Extracts the last lane from its operand if it is a vector, or the last
914915
// part if scalar. In the latter case, the recipe will be removed during

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
9191
return IntegerType::get(Ctx, 1);
9292
case VPInstruction::ComputeAnyOfResult:
9393
case VPInstruction::ComputeFindLastIVResult:
94+
case VPInstruction::ComputeMinMaxIdxResult:
9495
case VPInstruction::ComputeReductionResult: {
9596
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
9697
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

0 commit comments

Comments
 (0)