Skip to content

Commit f5c4c14

Browse files
committed
[LoopVectorize] Vectorize the reduction pattern of integer min/max with index. (2/2)
1 parent 60e3dd7 commit f5c4c14

File tree

11 files changed

+1689
-81
lines changed

11 files changed

+1689
-81
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ class RecurrenceDescriptor {
305305
/// Returns the sentinel value for FindLastIV recurrences to replace the start
306306
/// value.
307307
Value *getSentinelValue() const {
308-
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
308+
assert(
309+
(isFindLastIVRecurrenceKind(Kind) || isMinMaxIdxRecurrenceKind(Kind)) &&
310+
"Unexpected recurrence kind");
309311
Type *Ty = StartValue->getType();
310312
return ConstantInt::get(Ty,
311313
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src, Value *InitVal,
423423
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
424424
Value *Sentinel);
425425

426+
/// Create a reduction of the given vector \p Src for a reduction of the
427+
/// kind RecurKind::MinMaxFirstIdx or RecurKind::MinMaxLastIdx. The reduction
428+
/// operation is described by \p Desc.
429+
Value *createMinMaxIdxReduction(IRBuilderBase &B, Value *Src, Value *Start,
430+
const RecurrenceDescriptor &Desc);
431+
426432
/// Create an ordered reduction intrinsic using the given recurrence
427433
/// kind \p RdxKind.
428434
Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ class LoopVectorizationLegality {
307307
/// Return the fixed-order recurrences found in the loop.
308308
RecurrenceSet &getFixedOrderRecurrences() { return FixedOrderRecurrences; }
309309

310+
/// Return the min/max recurrences found in the loop. The map is keyed by the
/// min/max recurrence phi; the mapped value is the index reduction phi that
/// has been recorded for that recurrence, or null if none has been recorded.
311+
const SmallDenseMap<PHINode *, PHINode *> &getMinMaxRecurrences() {
312+
return MinMaxRecurrences;
313+
}
314+
310315
/// Returns the widest induction type.
311316
IntegerType *getWidestInductionType() { return WidestIndTy; }
312317

@@ -618,7 +623,7 @@ class LoopVectorizationLegality {
618623
RecurrenceSet FixedOrderRecurrences;
619624

620625
/// Holds the min/max recurrence variables.
621-
RecurrenceSet MinMaxRecurrences;
626+
SmallDenseMap<PHINode *, PHINode *> MinMaxRecurrences;
622627

623628
/// Holds the widest induction type encountered.
624629
IntegerType *WidestIndTy = nullptr;

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,25 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12481248
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
12491249
}
12501250

1251+
/// Emit the final value of a min/max-with-index reduction described by
/// \p Desc, whose recurrence kind must be a MinMaxIdx kind (asserted below).
/// A vector \p Src is collapsed to a scalar with an integer min-reduce for
/// RecurKind::MinMaxFirstIdx and an integer max-reduce otherwise; a non-vector
/// \p Src is used as is. Returns a select that yields \p Start when the
/// reduced value equals the descriptor's sentinel, and the reduced value
/// otherwise.
Value *llvm::createMinMaxIdxReduction(IRBuilderBase &Builder, Value *Src,
1252+
Value *Start,
1253+
const RecurrenceDescriptor &Desc) {
1254+
RecurKind Kind = Desc.getRecurrenceKind();
1255+
assert(RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1256+
"Unexpected reduction kind");
1257+
Value *Sentinel = Desc.getSentinelValue();
1258+
// Only a vector source needs a horizontal reduction; otherwise Src is already
// the value to compare against the sentinel.
Value *Rdx = Src;
1259+
if (Src->getType()->isVectorTy())
1260+
// The `true` argument is IRBuilder's IsSigned flag, i.e. the signed
// min/max reduction is requested.
Rdx = Kind == RecurKind::MinMaxFirstIdx
1261+
? Builder.CreateIntMinReduce(Src, true)
1262+
: Builder.CreateIntMaxReduce(Src, true);
1263+
// Correct the final reduction result back to the start value if the reduction
1264+
// result is the sentinel value.
1265+
Value *Cmp =
1266+
Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Sentinel, "rdx.select.cmp");
1267+
return Builder.CreateSelect(Cmp, Rdx, Start, "rdx.select");
1268+
}
1269+
12511270
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
12521271
FastMathFlags Flags) {
12531272
bool Negative = false;
@@ -1336,7 +1355,8 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
13361355
RecurKind Kind) {
13371356
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
13381357
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
1339-
"AnyOf or FindLastIV reductions are not supported.");
1358+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1359+
"AnyOf, FindLastIV and MinMaxIdx reductions are not supported.");
13401360
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13411361
auto *SrcTy = cast<VectorType>(Src->getType());
13421362
Type *SrcEltTy = SrcTy->getElementType();

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
851851
if (MinMaxRecurDes.getLoopExitInstr())
852852
AllowedExit.insert(MinMaxRecurDes.getLoopExitInstr());
853853
Reductions[Phi] = MinMaxRecurDes;
854-
MinMaxRecurrences.insert(Phi);
854+
MinMaxRecurrences.try_emplace(Phi);
855855
MinMaxRecurrenceChains[Phi] = std::move(Chain);
856856
continue;
857857
}
@@ -1093,10 +1093,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
10931093
if (!canVectorizeMinMaxRecurrence(Phi, Chain))
10941094
return false;
10951095
}
1096-
// FIXME: Remove this after the IR generation of min/max with index is
1097-
// supported.
1098-
if (!MinMaxRecurrences.empty())
1099-
return false;
11001096

11011097
return true;
11021098
}
@@ -1106,6 +1102,10 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11061102
assert(!Chain.empty() && "Unexpected empty recurrence chain");
11071103
assert(isMinMaxRecurrence(Phi) && "The PHI is not a min/max recurrence phi");
11081104

1105+
auto It = MinMaxRecurrences.find(Phi);
1106+
if (It->second)
1107+
return true;
1108+
11091109
auto IsMinMaxIdxReductionPhi = [this, Phi, &Chain](Value *Candidate) -> bool {
11101110
auto *IdxPhi = dyn_cast<PHINode>(Candidate);
11111111
if (!IdxPhi || !isReductionVariable(IdxPhi))
@@ -1150,7 +1150,17 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11501150

11511151
auto *TrueVal = IdxChainHead->getTrueValue();
11521152
auto *FalseVal = IdxChainHead->getFalseValue();
1153-
return IsMinMaxIdxReductionPhi(TrueVal) || IsMinMaxIdxReductionPhi(FalseVal);
1153+
PHINode *IdxPhi;
1154+
if (IsMinMaxIdxReductionPhi(TrueVal))
1155+
IdxPhi = cast<PHINode>(TrueVal);
1156+
else if (IsMinMaxIdxReductionPhi(FalseVal))
1157+
IdxPhi = cast<PHINode>(FalseVal);
1158+
else
1159+
return false;
1160+
1161+
// Record the index reduction phi that uses the min/max recurrence.
1162+
It->second = IdxPhi;
1163+
return true;
11541164
}
11551165

11561166
/// Find histogram operations that match high-level code in loops:
@@ -1973,7 +1983,8 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
19731983
SmallPtrSet<const Value *, 8> ReductionLiveOuts;
19741984

19751985
for (const auto &Reduction : getReductionVars())
1976-
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
1986+
if (auto *ExitInstr = Reduction.second.getLoopExitInstr())
1987+
ReductionLiveOuts.insert(ExitInstr);
19771988

19781989
// TODO: handle non-reduction outside users when tail is folded by masking.
19791990
for (auto *AE : AllowedExit) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4457,6 +4457,14 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
44574457
return false;
44584458
}
44594459

4460+
// TODO: support epilogue vectorization for min/max with index.
4461+
if (any_of(Legal->getReductionVars(), [](const auto &Reduction) {
4462+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
4463+
return RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(
4464+
RdxDesc.getRecurrenceKind());
4465+
}))
4466+
return false;
4467+
44604468
// Epilogue vectorization code has not been audited to ensure it handles
44614469
// non-latch exits properly. It may be fine, but it needs to be audited and
44624470
// tested.
@@ -4901,7 +4909,8 @@ LoopVectorizationCostModel::selectInterleaveCount(VPlan &Plan, ElementCount VF,
49014909
const RecurrenceDescriptor &RdxDesc = Reduction.second;
49024910
RecurKind RK = RdxDesc.getRecurrenceKind();
49034911
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
4904-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
4912+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) ||
4913+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK);
49054914
});
49064915
if (HasSelectCmpReductions) {
49074916
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -6618,6 +6627,10 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
66186627

66196628
for (const auto &Reduction : Legal->getReductionVars()) {
66206629
PHINode *Phi = Reduction.first;
6630+
// TODO: support in-loop min/max with index.
6631+
if (Legal->isMinMaxRecurrence(Phi))
6632+
continue;
6633+
66216634
const RecurrenceDescriptor &RdxDesc = Reduction.second;
66226635

66236636
// We don't collect reductions that are type promoted (yet).
@@ -7233,6 +7246,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
72337246
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
72347247
return;
72357248

7249+
assert(EpiRedResult->getOpcode() != VPInstruction::ComputeMinMaxIdxResult);
7250+
72367251
auto *EpiRedHeaderPhi =
72377252
cast<VPReductionPHIRecipe>(EpiRedResult->getOperand(0));
72387253
const RecurrenceDescriptor &RdxDesc =
@@ -8140,10 +8155,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
81408155
// Find all possible partial reductions.
81418156
SmallVector<std::pair<PartialReductionChain, unsigned>>
81428157
PartialReductionChains;
8143-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8144-
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8145-
PartialReductionChains);
8146-
}
8158+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8159+
if (auto *ExitInstr = RdxDesc.getLoopExitInstr())
8160+
getScaledReductions(Phi, ExitInstr, Range, PartialReductionChains);
81478161

81488162
// A partial reduction is invalid if any of its extends are used by
81498163
// something that isn't another partial reduction. This is because the
@@ -9037,6 +9051,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
90379051
assert(
90389052
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
90399053
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9054+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
90409055
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
90419056

90429057
// Collect the chain of "link" recipes for the reduction starting at PhiR.
@@ -9160,15 +9175,32 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91609175
PreviousLink = RedRecipe;
91619176
}
91629177
}
9178+
9179+
// Collect all VPReductionPHIRecipes in the header block, and sort them based
9180+
// on the dependency order of the reductions. This ensures that results of
9181+
// min/max reductions are computed before their corresponding index
9182+
// reductions, since the index reduction relies on the result of the min/max
9183+
// reduction to determine which lane produced the min/max.
9184+
SmallVector<VPReductionPHIRecipe *> VPReductionPHIs;
9185+
for (VPRecipeBase &R : Header->phis())
9186+
if (auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R))
9187+
VPReductionPHIs.push_back(PhiR);
9188+
9189+
stable_sort(VPReductionPHIs, [this](const VPReductionPHIRecipe *R1,
9190+
const VPReductionPHIRecipe *R2) {
9191+
auto *Phi1 = cast<PHINode>(R1->getUnderlyingInstr());
9192+
if (!Legal->isMinMaxRecurrence(Phi1))
9193+
return false;
9194+
9195+
auto *Phi2 = cast<PHINode>(R2->getUnderlyingInstr());
9196+
return Legal->getMinMaxRecurrences().find(Phi1)->second == Phi2;
9197+
});
9198+
91639199
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
91649200
Builder.setInsertPoint(&*std::prev(std::prev(LatchVPBB->end())));
91659201
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
9166-
for (VPRecipeBase &R :
9167-
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
9168-
VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
9169-
if (!PhiR)
9170-
continue;
9171-
9202+
SmallDenseMap<VPReductionPHIRecipe *, VPValue *> IdxReductionMasks;
9203+
for (auto *PhiR : VPReductionPHIs) {
91729204
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
91739205
Type *PhiTy = PhiR->getUnderlyingValue()->getType();
91749206
// If tail is folded by masking, introduce selects between the phi
@@ -9195,7 +9227,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
91959227
cast<VPInstruction>(&U)->getOpcode() ==
91969228
VPInstruction::ComputeReductionResult ||
91979229
cast<VPInstruction>(&U)->getOpcode() ==
9198-
VPInstruction::ComputeFindLastIVResult);
9230+
VPInstruction::ComputeFindLastIVResult ||
9231+
cast<VPInstruction>(&U)->getOpcode() ==
9232+
VPInstruction::ComputeMinMaxIdxResult);
91999233
});
92009234
if (CM.usePredicatedReductionSelect())
92019235
PhiR->setOperand(1, NewExitingVPV);
@@ -9239,6 +9273,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92399273
VPInstruction *FinalReductionResult;
92409274
VPBuilder::InsertPointGuard Guard(Builder);
92419275
Builder.setInsertPoint(MiddleVPBB, IP);
9276+
RecurKind RK = RdxDesc.getRecurrenceKind();
92429277
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
92439278
RdxDesc.getRecurrenceKind())) {
92449279
VPValue *Start = PhiR->getStartValue();
@@ -9251,6 +9286,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92519286
FinalReductionResult =
92529287
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
92539288
{PhiR, Start, NewExitingVPV}, ExitDL);
9289+
} else if (RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK)) {
9290+
// Mask out lanes that cannot be the index of the min/max value.
9291+
VPValue *Mask = IdxReductionMasks.at(PhiR);
9292+
Value *Iden = llvm::getRecurrenceIdentity(
9293+
RK == RecurKind::MinMaxFirstIdx ? RecurKind::SMin : RecurKind::SMax,
9294+
PhiTy, RdxDesc.getFastMathFlags());
9295+
NewExitingVPV = Builder.createSelect(Mask, NewExitingVPV,
9296+
Plan->getOrAddLiveIn(Iden), ExitDL);
9297+
9298+
VPValue *Start = PhiR->getStartValue();
9299+
FinalReductionResult =
9300+
Builder.createNaryOp(VPInstruction::ComputeMinMaxIdxResult,
9301+
{PhiR, Start, NewExitingVPV}, ExitDL);
92549302
} else {
92559303
VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
92569304
RdxDesc.getRecurrenceKind())
@@ -9262,11 +9310,25 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
92629310
}
92639311
// Update all users outside the vector region.
92649312
OrigExitingVPV->replaceUsesWithIf(
9265-
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
9313+
FinalReductionResult,
9314+
[FinalReductionResult, NewExitingVPV](VPUser &User, unsigned) {
92669315
auto *Parent = cast<VPRecipeBase>(&User)->getParent();
9267-
return FinalReductionResult != &User && !Parent->getParent();
9316+
return FinalReductionResult != &User &&
9317+
NewExitingVPV->getDefiningRecipe() != &User &&
9318+
!Parent->getParent();
92689319
});
92699320

9321+
// Generate a mask for the index reduction.
9322+
auto *Phi = cast<PHINode>(PhiR->getUnderlyingInstr());
9323+
if (Legal->isMinMaxRecurrence(Phi)) {
9324+
VPValue *IdxRdxMask = Builder.createICmp(CmpInst::ICMP_EQ, NewExitingVPV,
9325+
FinalReductionResult, ExitDL);
9326+
PHINode *IdxPhi = Legal->getMinMaxRecurrences().find(Phi)->second;
9327+
IdxReductionMasks.try_emplace(
9328+
cast<VPReductionPHIRecipe>(RecipeBuilder.getRecipe(IdxPhi)),
9329+
IdxRdxMask);
9330+
}
9331+
92709332
// Adjust AnyOf reductions; replace the reduction phi for the selected value
92719333
// with a boolean reduction phi node to check if the condition is true in
92729334
// any iteration. The final value is selected by the final
@@ -9301,16 +9363,17 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93019363
continue;
93029364
}
93039365

9304-
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9305-
RdxDesc.getRecurrenceKind())) {
9306-
// Adjust the start value for FindLastIV recurrences to use the sentinel
9307-
// value after generating the ResumePhi recipe, which uses the original
9308-
// start value.
9366+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) ||
9367+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK)) {
9368+
// Adjust the start value for FindLastIV/MinMaxIdx recurrences to use the
9369+
// sentinel value after generating the ResumePhi recipe, which uses the
9370+
// original start value.
93099371
PhiR->setOperand(0, Plan->getOrAddLiveIn(RdxDesc.getSentinelValue()));
93109372
}
9311-
RecurKind RK = RdxDesc.getRecurrenceKind();
9373+
93129374
if ((!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
93139375
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
9376+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK) &&
93149377
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))) {
93159378
VPBuilder PHBuilder(Plan->getVectorPreheader());
93169379
VPValue *Iden = Plan->getOrAddLiveIn(

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
909909
Broadcast,
910910
ComputeAnyOfResult,
911911
ComputeFindLastIVResult,
912+
ComputeMinMaxIdxResult,
912913
ComputeReductionResult,
913914
// Extracts the last lane from its operand if it is a vector, or the last
914915
// part if scalar. In the latter case, the recipe will be removed during

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
9292
return IntegerType::get(Ctx, 1);
9393
case VPInstruction::ComputeAnyOfResult:
9494
case VPInstruction::ComputeFindLastIVResult:
95+
case VPInstruction::ComputeMinMaxIdxResult:
9596
case VPInstruction::ComputeReductionResult: {
9697
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
9798
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

0 commit comments

Comments
 (0)