@@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
51
51
case RecurKind::UMin:
52
52
case RecurKind::AnyOf:
53
53
case RecurKind::FindLastIV:
54
+ case RecurKind::MinMaxFirstIdx:
55
+ case RecurKind::MinMaxLastIdx:
54
56
return true ;
55
57
}
56
58
return false ;
@@ -1130,6 +1132,226 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
1130
1132
return true ;
1131
1133
}
1132
1134
1135
+ // / Return the recurrence kind if \p I is matched by the min/max operation
1136
+ // / pattern. Otherwise, return RecurKind::None.
1137
+ static RecurKind isMinMaxRecurOp (const Instruction *I) {
1138
+ if (match (I, m_UMin (m_Value (), m_Value ())))
1139
+ return RecurKind::UMin;
1140
+ if (match (I, m_UMax (m_Value (), m_Value ())))
1141
+ return RecurKind::UMax;
1142
+ if (match (I, m_SMax (m_Value (), m_Value ())))
1143
+ return RecurKind::SMax;
1144
+ if (match (I, m_SMin (m_Value (), m_Value ())))
1145
+ return RecurKind::SMin;
1146
+ // TODO: support fp-min/max
1147
+ return RecurKind::None;
1148
+ }
1149
+
1150
+ SmallVector<Instruction *, 2 >
1151
+ RecurrenceDescriptor::tryToGetMinMaxRecurrenceChain (
1152
+ PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RecurDes) {
1153
+ SmallVector<Instruction *, 2 > Chain;
1154
+ // Check the phi is in the loop header and has two incoming values.
1155
+ if (Phi->getParent () != TheLoop->getHeader () ||
1156
+ Phi->getNumIncomingValues () != 2 )
1157
+ return {};
1158
+
1159
+ // Ensure the loop has a preheader and a latch block.
1160
+ auto *Preheader = TheLoop->getLoopPreheader ();
1161
+ auto *Latch = TheLoop->getLoopLatch ();
1162
+ if (!Preheader || !Latch)
1163
+ return {};
1164
+
1165
+ // Ensure that one of the incoming values of the PHI node is from the
1166
+ // preheader, and the other one is from the loop latch.
1167
+ if (Phi->getBasicBlockIndex (Preheader) < 0 ||
1168
+ Phi->getBasicBlockIndex (Latch) < 0 )
1169
+ return {};
1170
+
1171
+ Value *StartValue = Phi->getIncomingValueForBlock (Preheader);
1172
+ auto *BEValue = dyn_cast<Instruction>(Phi->getIncomingValueForBlock (Latch));
1173
+ if (!BEValue || BEValue == Phi)
1174
+ return {};
1175
+
1176
+ auto HasLoopExternalUse = [TheLoop](const Instruction *I) {
1177
+ return any_of (I->users (), [TheLoop](auto *U) {
1178
+ return !TheLoop->contains (cast<Instruction>(U));
1179
+ });
1180
+ };
1181
+
1182
+ // Ensure the recurrence phi has no users outside the loop, as such cases
1183
+ // cannot be vectorized.
1184
+ if (HasLoopExternalUse (Phi))
1185
+ return {};
1186
+
1187
+ // Ensure the backedge value of the phi is only used internally by the phi;
1188
+ // all other users must be outside the loop.
1189
+ // TODO: support intermediate store.
1190
+ if (any_of (BEValue->users (), [&](auto *U) {
1191
+ auto *UI = cast<Instruction>(U);
1192
+ return TheLoop->contains (UI) && UI != Phi;
1193
+ }))
1194
+ return {};
1195
+
1196
+ // Ensure the backedge value of the phi matches the min/max operation pattern.
1197
+ RecurKind TargetKind = isMinMaxRecurOp (BEValue);
1198
+ if (TargetKind == RecurKind::None)
1199
+ return {};
1200
+
1201
+ // TODO: type-promoted recurrence
1202
+ SmallPtrSet<Instruction *, 4 > CastInsts;
1203
+
1204
+ // Trace the use-def chain from the backedge value to the phi, ensuring a
1205
+ // unique in-loop path where all operations match the expected recurrence
1206
+ // kind.
1207
+ bool FoundRecurPhi = false ;
1208
+ SmallVector<Instruction *, 8 > Worklist (1 , BEValue);
1209
+ SmallDenseMap<Instruction *, Instruction *, 4 > VisitedFrom;
1210
+
1211
+ VisitedFrom.try_emplace (BEValue);
1212
+
1213
+ while (!Worklist.empty ()) {
1214
+ Instruction *Cur = Worklist.pop_back_val ();
1215
+ if (Cur == Phi) {
1216
+ if (FoundRecurPhi)
1217
+ return {};
1218
+ FoundRecurPhi = true ;
1219
+ continue ;
1220
+ }
1221
+
1222
+ if (!TheLoop->contains (Cur))
1223
+ continue ;
1224
+
1225
+ // TODO: support the min/max recurrence in cmp-select pattern.
1226
+ if (!isa<CallInst>(Cur) || isMinMaxRecurOp (Cur) != TargetKind)
1227
+ continue ;
1228
+
1229
+ for (Use &Op : Cur->operands ()) {
1230
+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
1231
+ if (!VisitedFrom.try_emplace (OpInst, Cur).second )
1232
+ return {};
1233
+ Worklist.push_back (OpInst);
1234
+ }
1235
+ }
1236
+ }
1237
+
1238
+ if (!FoundRecurPhi)
1239
+ return {};
1240
+
1241
+ Instruction *ExitInstruction = nullptr ;
1242
+ // Get the recurrence chain by visited trace.
1243
+ Instruction *VisitedInst = VisitedFrom.at (Phi);
1244
+ while (VisitedInst) {
1245
+ // Ensure that no instruction in the recurrence chain is used outside the
1246
+ // loop, except for the backedge value, which is permitted.
1247
+ if (HasLoopExternalUse (VisitedInst)) {
1248
+ if (VisitedInst != BEValue)
1249
+ return {};
1250
+ ExitInstruction = BEValue;
1251
+ }
1252
+ Chain.push_back (VisitedInst);
1253
+ VisitedInst = VisitedFrom.at (VisitedInst);
1254
+ }
1255
+
1256
+ RecurDes = RecurrenceDescriptor (
1257
+ StartValue, ExitInstruction, /* IntermediateStore=*/ nullptr , TargetKind,
1258
+ FastMathFlags (), /* ExactFPMathInst=*/ nullptr , Phi->getType (),
1259
+ /* IsSigned=*/ false , /* IsOrdered=*/ false , CastInsts,
1260
+ /* MinWidthCastToRecurTy=*/ -1U );
1261
+
1262
+ LLVM_DEBUG (dbgs () << " Found a min/max recurrence PHI: " << *Phi << " \n " );
1263
+
1264
+ return Chain;
1265
+ }
1266
+
1267
+ bool RecurrenceDescriptor::isMinMaxIdxReduction (
1268
+ PHINode *IdxPhi, PHINode *MinMaxPhi, const RecurrenceDescriptor &MinMaxDesc,
1269
+ ArrayRef<Instruction *> MinMaxChain) {
1270
+ // Return early if the recurrence kind is already known to be min/max with
1271
+ // index.
1272
+ if (isMinMaxIdxRecurrenceKind (Kind))
1273
+ return true ;
1274
+
1275
+ if (!isFindLastIVRecurrenceKind (Kind))
1276
+ return false ;
1277
+
1278
+ // Ensure index reduction phi and min/max recurrence phi are in the same basic
1279
+ // block.
1280
+ if (IdxPhi->getParent () != MinMaxPhi->getParent ())
1281
+ return false ;
1282
+
1283
+ RecurKind MinMaxRK = MinMaxDesc.getRecurrenceKind ();
1284
+ // TODO: support floating-point min/max with index.
1285
+ if (!isIntMinMaxRecurrenceKind (MinMaxRK))
1286
+ return false ;
1287
+
1288
+ // FindLastIV only supports a single select operation in the recurrence chain
1289
+ // so far. Therefore, do not consider min/max recurrences with more than one
1290
+ // operation in the recurrence chain.
1291
+ // TODO: support FindLastIV with multiple operations in the recurrence chain.
1292
+ if (MinMaxChain.size () != 1 )
1293
+ return false ;
1294
+
1295
+ Instruction *MinMaxChainCur = MinMaxPhi;
1296
+ Instruction *MinMaxChainNext = MinMaxChain.front ();
1297
+ Value *OutOfChain;
1298
+ bool IsMinMaxOperation = match (
1299
+ MinMaxChainNext,
1300
+ m_CombineOr (m_MaxOrMin (m_Specific (MinMaxChainCur), m_Value (OutOfChain)),
1301
+ m_MaxOrMin (m_Value (OutOfChain), m_Specific (MinMaxChainCur))));
1302
+ assert (IsMinMaxOperation && " Unexpected operation in the recurrence chain" );
1303
+
1304
+ auto *IdxExit = cast<SelectInst>(LoopExitInstr);
1305
+ Value *IdxCond = IdxExit->getCondition ();
1306
+ // Check if the operands used by cmp instruction of index select is the same
1307
+ // as the operands used by min/max recurrence.
1308
+ bool IsMatchLHSInMinMaxChain =
1309
+ match (IdxCond, m_Cmp (m_Specific (MinMaxChainCur), m_Specific (OutOfChain)));
1310
+ bool IsMatchRHSInMinMaxChain =
1311
+ match (IdxCond, m_Cmp (m_Specific (OutOfChain), m_Specific (MinMaxChainCur)));
1312
+ if (!IsMatchLHSInMinMaxChain && !IsMatchRHSInMinMaxChain)
1313
+ return false ;
1314
+
1315
+ CmpInst::Predicate IdxPred = cast<CmpInst>(IdxCond)->getPredicate ();
1316
+ // The predicate of cmp instruction must be relational in min/max with index.
1317
+ if (CmpInst::isEquality (IdxPred))
1318
+ return false ;
1319
+
1320
+ // Normalize predicate from
1321
+ // m_Cmp(pred, out_of_chain, in_chain)
1322
+ // to
1323
+ // m_Cmp(swapped_pred, in_chain, out_of_chain).
1324
+ if (IsMatchRHSInMinMaxChain)
1325
+ IdxPred = CmpInst::getSwappedPredicate (IdxPred);
1326
+
1327
+ // Verify that the select operation is updated on the correct side based on
1328
+ // the min/max kind.
1329
+ bool IsTrueUpdateIdx = IdxExit->getFalseValue () == IdxPhi;
1330
+ bool IsMaxRK = isIntMaxRecurrenceKind (MinMaxRK);
1331
+ bool IsLess = ICmpInst::isLT (IdxPred) || ICmpInst::isLE (IdxPred);
1332
+ bool IsExpectedTrueUpdateIdx = IsMaxRK == IsLess;
1333
+ if (IsTrueUpdateIdx != IsExpectedTrueUpdateIdx)
1334
+ return false ;
1335
+
1336
+ RecurKind NewIdxRK;
1337
+ // The index recurrence kind is the same for both the predicate and its
1338
+ // inverse.
1339
+ if (!IsLess)
1340
+ IdxPred = CmpInst::getInversePredicate (IdxPred);
1341
+ // For max recurrence, a strict less-than predicate indicates that the first
1342
+ // matching index will be selected. For min recurrence, the opposite holds.
1343
+ NewIdxRK = IsMaxRK != ICmpInst::isLE (IdxPred) ? RecurKind::MinMaxFirstIdx
1344
+ : RecurKind::MinMaxLastIdx;
1345
+
1346
+ // Update the kind of index recurrence.
1347
+ Kind = NewIdxRK;
1348
+ LLVM_DEBUG (
1349
+ dbgs () << " Found a min/max with "
1350
+ << (NewIdxRK == RecurKind::MinMaxFirstIdx ? " first" : " last" )
1351
+ << " index reduction PHI." << *IdxPhi << " \n " );
1352
+ return true ;
1353
+ }
1354
+
1133
1355
unsigned RecurrenceDescriptor::getOpcode (RecurKind Kind) {
1134
1356
switch (Kind) {
1135
1357
case RecurKind::Add:
0 commit comments