Skip to content

Commit 2512806

Browse files
authored
Merge 084746e into 41a4b04
2 parents 41a4b04 + 084746e commit 2512806

File tree

6 files changed

+398
-8
lines changed

6 files changed

+398
-8
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ enum class RecurKind {
5757
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
5858
///< (x,y) is increasing loop induction, and both x and y are
5959
///< integer type.
60+
MinMaxFirstIdx, ///< Integer Min/Max with first index
61+
MinMaxLastIdx, ///< Integer Min/Max with last index
6062
// clang-format on
6163
// TODO: Any_of and FindLast reduction need not be restricted to integer type
6264
// only.
@@ -210,6 +212,26 @@ class RecurrenceDescriptor {
210212
LLVM_ABI static bool isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
211213
DominatorTree *DT);
212214

215+
/// Returns the recurrence chain if \p Phi is an integer min/max recurrence in
216+
/// \p TheLoop. The RecurrenceDescriptor is returned in \p RecurDes.
217+
static SmallVector<Instruction *, 2>
218+
tryToGetMinMaxRecurrenceChain(PHINode *Phi, Loop *TheLoop,
219+
RecurrenceDescriptor &RecurDes);
220+
221+
/// Returns true if the recurrence is a min/max with index pattern, and
222+
/// updates the recurrence kind to RecurKind::MinMaxFirstIdx or
223+
/// RecurKind::MinMaxLastIdx.
224+
///
225+
/// \param IdxPhi The phi representing the index recurrence.
226+
/// \param MinMaxPhi The phi representing the min/max recurrence involved
227+
/// in the min/max with index pattern.
228+
/// \param MinMaxDesc The descriptor of the min/max recurrence.
229+
/// \param MinMaxChain The chain of instructions involved in the min/max
230+
/// recurrence.
231+
bool isMinMaxIdxReduction(PHINode *IdxPhi, PHINode *MinMaxPhi,
232+
const RecurrenceDescriptor &MinMaxDesc,
233+
ArrayRef<Instruction *> MinMaxChain);
234+
213235
RecurKind getRecurrenceKind() const { return Kind; }
214236

215237
unsigned getOpcode() const { return getOpcode(getRecurrenceKind()); }
@@ -263,6 +285,20 @@ class RecurrenceDescriptor {
263285
return Kind == RecurKind::FindLastIV;
264286
}
265287

288+
/// Returns true if the recurrence kind is of the form:
289+
/// select(icmp(a,b),x,y)
290+
/// where one of (x,y) is an increasing loop induction variable, and icmp(a,b)
291+
/// depends on a min/max recurrence.
292+
static bool isMinMaxIdxRecurrenceKind(RecurKind Kind) {
293+
return Kind == RecurKind::MinMaxFirstIdx ||
294+
Kind == RecurKind::MinMaxLastIdx;
295+
}
296+
297+
/// Returns true if the recurrence kind is an integer max kind.
298+
static bool isIntMaxRecurrenceKind(RecurKind Kind) {
299+
return Kind == RecurKind::UMax || Kind == RecurKind::SMax;
300+
}
301+
266302
/// Returns the type of the recurrence. This type can be narrower than the
267303
/// actual type of the Phi if the recurrence has been type-promoted.
268304
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ class LoopVectorizationLegality {
345345
/// Returns True if Phi is a fixed-order recurrence in this loop.
346346
bool isFixedOrderRecurrence(const PHINode *Phi) const;
347347

348+
/// Returns True if \p Phi is a min/max recurrence in this loop.
349+
bool isMinMaxRecurrence(const PHINode *Phi) const;
350+
348351
/// Return true if the block BB needs to be predicated in order for the loop
349352
/// to be vectorized.
350353
bool blockNeedsPredication(BasicBlock *BB) const;
@@ -519,6 +522,14 @@ class LoopVectorizationLegality {
519522
/// specific checks for outer loop vectorization.
520523
bool canVectorizeOuterLoop();
521524

525+
// Min/max recurrences can only be vectorized when involved in a min/max with
526+
// index reduction pattern. This function checks whether the \p Phi, which
527+
// represents the min/max recurrence, can be vectorized based on the given \p
528+
// Chain, which is the recurrence chain for the min/max recurrence. Returns
529+
// true if the min/max recurrence can be vectorized.
530+
bool canVectorizeMinMaxRecurrence(PHINode *Phi,
531+
ArrayRef<Instruction *> Chain);
532+
522533
/// Returns true if this is an early exit loop that can be vectorized.
523534
/// Currently, a loop with an uncountable early exit is considered
524535
/// vectorizable if:
@@ -606,6 +617,9 @@ class LoopVectorizationLegality {
606617
/// Holds the phi nodes that are fixed-order recurrences.
607618
RecurrenceSet FixedOrderRecurrences;
608619

620+
/// Holds the min/max recurrences variables.
621+
RecurrenceSet MinMaxRecurrences;
622+
609623
/// Holds the widest induction type encountered.
610624
IntegerType *WidestIndTy = nullptr;
611625

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5151
case RecurKind::UMin:
5252
case RecurKind::AnyOf:
5353
case RecurKind::FindLastIV:
54+
case RecurKind::MinMaxFirstIdx:
55+
case RecurKind::MinMaxLastIdx:
5456
return true;
5557
}
5658
return false;
@@ -1130,6 +1132,226 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
11301132
return true;
11311133
}
11321134

1135+
/// Return the recurrence kind if \p I is matched by the min/max operation
1136+
/// pattern. Otherwise, return RecurKind::None.
1137+
static RecurKind isMinMaxRecurOp(const Instruction *I) {
1138+
if (match(I, m_UMin(m_Value(), m_Value())))
1139+
return RecurKind::UMin;
1140+
if (match(I, m_UMax(m_Value(), m_Value())))
1141+
return RecurKind::UMax;
1142+
if (match(I, m_SMax(m_Value(), m_Value())))
1143+
return RecurKind::SMax;
1144+
if (match(I, m_SMin(m_Value(), m_Value())))
1145+
return RecurKind::SMin;
1146+
// TODO: support fp-min/max
1147+
return RecurKind::None;
1148+
}
1149+
1150+
SmallVector<Instruction *, 2>
1151+
RecurrenceDescriptor::tryToGetMinMaxRecurrenceChain(
1152+
PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RecurDes) {
1153+
SmallVector<Instruction *, 2> Chain;
1154+
// Check the phi is in the loop header and has two incoming values.
1155+
if (Phi->getParent() != TheLoop->getHeader() ||
1156+
Phi->getNumIncomingValues() != 2)
1157+
return {};
1158+
1159+
// Ensure the loop has a preheader and a latch block.
1160+
auto *Preheader = TheLoop->getLoopPreheader();
1161+
auto *Latch = TheLoop->getLoopLatch();
1162+
if (!Preheader || !Latch)
1163+
return {};
1164+
1165+
// Ensure that one of the incoming values of the PHI node is from the
1166+
// preheader, and the other one is from the loop latch.
1167+
if (Phi->getBasicBlockIndex(Preheader) < 0 ||
1168+
Phi->getBasicBlockIndex(Latch) < 0)
1169+
return {};
1170+
1171+
Value *StartValue = Phi->getIncomingValueForBlock(Preheader);
1172+
auto *BEValue = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch));
1173+
if (!BEValue || BEValue == Phi)
1174+
return {};
1175+
1176+
auto HasLoopExternalUse = [TheLoop](const Instruction *I) {
1177+
return any_of(I->users(), [TheLoop](auto *U) {
1178+
return !TheLoop->contains(cast<Instruction>(U));
1179+
});
1180+
};
1181+
1182+
// Ensure the recurrence phi has no users outside the loop, as such cases
1183+
// cannot be vectorized.
1184+
if (HasLoopExternalUse(Phi))
1185+
return {};
1186+
1187+
// Ensure the backedge value of the phi is only used internally by the phi;
1188+
// all other users must be outside the loop.
1189+
// TODO: support intermediate store.
1190+
if (any_of(BEValue->users(), [&](auto *U) {
1191+
auto *UI = cast<Instruction>(U);
1192+
return TheLoop->contains(UI) && UI != Phi;
1193+
}))
1194+
return {};
1195+
1196+
// Ensure the backedge value of the phi matches the min/max operation pattern.
1197+
RecurKind TargetKind = isMinMaxRecurOp(BEValue);
1198+
if (TargetKind == RecurKind::None)
1199+
return {};
1200+
1201+
// TODO: type-promoted recurrence
1202+
SmallPtrSet<Instruction *, 4> CastInsts;
1203+
1204+
// Trace the use-def chain from the backedge value to the phi, ensuring a
1205+
// unique in-loop path where all operations match the expected recurrence
1206+
// kind.
1207+
bool FoundRecurPhi = false;
1208+
SmallVector<Instruction *, 8> Worklist(1, BEValue);
1209+
SmallDenseMap<Instruction *, Instruction *, 4> VisitedFrom;
1210+
1211+
VisitedFrom.try_emplace(BEValue);
1212+
1213+
while (!Worklist.empty()) {
1214+
Instruction *Cur = Worklist.pop_back_val();
1215+
if (Cur == Phi) {
1216+
if (FoundRecurPhi)
1217+
return {};
1218+
FoundRecurPhi = true;
1219+
continue;
1220+
}
1221+
1222+
if (!TheLoop->contains(Cur))
1223+
continue;
1224+
1225+
// TODO: support the min/max recurrence in cmp-select pattern.
1226+
if (!isa<CallInst>(Cur) || isMinMaxRecurOp(Cur) != TargetKind)
1227+
continue;
1228+
1229+
for (Use &Op : Cur->operands()) {
1230+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
1231+
if (!VisitedFrom.try_emplace(OpInst, Cur).second)
1232+
return {};
1233+
Worklist.push_back(OpInst);
1234+
}
1235+
}
1236+
}
1237+
1238+
if (!FoundRecurPhi)
1239+
return {};
1240+
1241+
Instruction *ExitInstruction = nullptr;
1242+
// Get the recurrence chain by visited trace.
1243+
Instruction *VisitedInst = VisitedFrom.at(Phi);
1244+
while (VisitedInst) {
1245+
// Ensure that no instruction in the recurrence chain is used outside the
1246+
// loop, except for the backedge value, which is permitted.
1247+
if (HasLoopExternalUse(VisitedInst)) {
1248+
if (VisitedInst != BEValue)
1249+
return {};
1250+
ExitInstruction = BEValue;
1251+
}
1252+
Chain.push_back(VisitedInst);
1253+
VisitedInst = VisitedFrom.at(VisitedInst);
1254+
}
1255+
1256+
RecurDes = RecurrenceDescriptor(
1257+
StartValue, ExitInstruction, /*IntermediateStore=*/nullptr, TargetKind,
1258+
FastMathFlags(), /*ExactFPMathInst=*/nullptr, Phi->getType(),
1259+
/*IsSigned=*/false, /*IsOrdered=*/false, CastInsts,
1260+
/*MinWidthCastToRecurTy=*/-1U);
1261+
1262+
LLVM_DEBUG(dbgs() << "Found a min/max recurrence PHI: " << *Phi << "\n");
1263+
1264+
return Chain;
1265+
}
1266+
1267+
bool RecurrenceDescriptor::isMinMaxIdxReduction(
1268+
PHINode *IdxPhi, PHINode *MinMaxPhi, const RecurrenceDescriptor &MinMaxDesc,
1269+
ArrayRef<Instruction *> MinMaxChain) {
1270+
// Return early if the recurrence kind is already known to be min/max with
1271+
// index.
1272+
if (isMinMaxIdxRecurrenceKind(Kind))
1273+
return true;
1274+
1275+
if (!isFindLastIVRecurrenceKind(Kind))
1276+
return false;
1277+
1278+
// Ensure index reduction phi and min/max recurrence phi are in the same basic
1279+
// block.
1280+
if (IdxPhi->getParent() != MinMaxPhi->getParent())
1281+
return false;
1282+
1283+
RecurKind MinMaxRK = MinMaxDesc.getRecurrenceKind();
1284+
// TODO: support floating-point min/max with index.
1285+
if (!isIntMinMaxRecurrenceKind(MinMaxRK))
1286+
return false;
1287+
1288+
// FindLastIV only supports a single select operation in the recurrence chain
1289+
// so far. Therefore, do not consider min/max recurrences with more than one
1290+
// operation in the recurrence chain.
1291+
// TODO: support FindLastIV with multiple operations in the recurrence chain.
1292+
if (MinMaxChain.size() != 1)
1293+
return false;
1294+
1295+
Instruction *MinMaxChainCur = MinMaxPhi;
1296+
Instruction *MinMaxChainNext = MinMaxChain.front();
1297+
Value *OutOfChain;
1298+
bool IsMinMaxOperation = match(
1299+
MinMaxChainNext,
1300+
m_CombineOr(m_MaxOrMin(m_Specific(MinMaxChainCur), m_Value(OutOfChain)),
1301+
m_MaxOrMin(m_Value(OutOfChain), m_Specific(MinMaxChainCur))));
1302+
assert(IsMinMaxOperation && "Unexpected operation in the recurrence chain");
1303+
1304+
auto *IdxExit = cast<SelectInst>(LoopExitInstr);
1305+
Value *IdxCond = IdxExit->getCondition();
1306+
// Check if the operands used by cmp instruction of index select is the same
1307+
// as the operands used by min/max recurrence.
1308+
bool IsMatchLHSInMinMaxChain =
1309+
match(IdxCond, m_Cmp(m_Specific(MinMaxChainCur), m_Specific(OutOfChain)));
1310+
bool IsMatchRHSInMinMaxChain =
1311+
match(IdxCond, m_Cmp(m_Specific(OutOfChain), m_Specific(MinMaxChainCur)));
1312+
if (!IsMatchLHSInMinMaxChain && !IsMatchRHSInMinMaxChain)
1313+
return false;
1314+
1315+
CmpInst::Predicate IdxPred = cast<CmpInst>(IdxCond)->getPredicate();
1316+
// The predicate of cmp instruction must be relational in min/max with index.
1317+
if (CmpInst::isEquality(IdxPred))
1318+
return false;
1319+
1320+
// Normalize predicate from
1321+
// m_Cmp(pred, out_of_chain, in_chain)
1322+
// to
1323+
// m_Cmp(swapped_pred, in_chain, out_of_chain).
1324+
if (IsMatchRHSInMinMaxChain)
1325+
IdxPred = CmpInst::getSwappedPredicate(IdxPred);
1326+
1327+
// Verify that the select operation is updated on the correct side based on
1328+
// the min/max kind.
1329+
bool IsTrueUpdateIdx = IdxExit->getFalseValue() == IdxPhi;
1330+
bool IsMaxRK = isIntMaxRecurrenceKind(MinMaxRK);
1331+
bool IsLess = ICmpInst::isLT(IdxPred) || ICmpInst::isLE(IdxPred);
1332+
bool IsExpectedTrueUpdateIdx = IsMaxRK == IsLess;
1333+
if (IsTrueUpdateIdx != IsExpectedTrueUpdateIdx)
1334+
return false;
1335+
1336+
RecurKind NewIdxRK;
1337+
// The index recurrence kind is the same for both the predicate and its
1338+
// inverse.
1339+
if (!IsLess)
1340+
IdxPred = CmpInst::getInversePredicate(IdxPred);
1341+
// For max recurrence, a strict less-than predicate indicates that the first
1342+
// matching index will be selected. For min recurrence, the opposite holds.
1343+
NewIdxRK = IsMaxRK != ICmpInst::isLE(IdxPred) ? RecurKind::MinMaxFirstIdx
1344+
: RecurKind::MinMaxLastIdx;
1345+
1346+
// Update the kind of index recurrence.
1347+
Kind = NewIdxRK;
1348+
LLVM_DEBUG(
1349+
dbgs() << "Found a min/max with "
1350+
<< (NewIdxRK == RecurKind::MinMaxFirstIdx ? "first" : "last")
1351+
<< " index reduction PHI." << *IdxPhi << "\n");
1352+
return true;
1353+
}
1354+
11331355
unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11341356
switch (Kind) {
11351357
case RecurKind::Add:

0 commit comments

Comments
 (0)