@@ -992,8 +992,10 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
992
992
bool LoopIdiomVectorize::recognizeFindFirstByte () {
993
993
// Currently the transformation only works on scalable vector types, although
994
994
// there is no fundamental reason why it cannot be made to work for fixed
995
- // vectors too.
996
- if (!TTI->supportsScalableVectors () || DisableFindFirstByte)
995
+ // vectors. We also need to know the target's minimum page size in order to
996
+ // generate runtime memory checks to ensure the vector version won't fault.
997
+ if (!TTI->supportsScalableVectors () || !TTI->getMinPageSize ().has_value () ||
998
+ DisableFindFirstByte)
997
999
return false ;
998
1000
999
1001
// Define some constants we need throughout.
@@ -1049,30 +1051,33 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
1049
1051
// %22 = icmp eq i8 %15, %21
1050
1052
// br i1 %22, label %ExitSucc, label %InnerBB
1051
1053
BasicBlock *ExitSucc, *InnerBB;
1052
- Value *LoadA , *LoadB ;
1053
- ICmpInst::Predicate MatchPred;
1054
+ Value *LoadSearch , *LoadNeedle ;
1055
+ CmpPredicate MatchPred;
1054
1056
if (!match (MatchBB->getTerminator (),
1055
- m_Br (m_ICmp (MatchPred, m_Value (LoadA ), m_Value (LoadB )),
1057
+ m_Br (m_ICmp (MatchPred, m_Value (LoadSearch ), m_Value (LoadNeedle )),
1056
1058
m_BasicBlock (ExitSucc), m_BasicBlock (InnerBB))) ||
1057
- MatchPred != ICmpInst::Predicate::ICMP_EQ ||
1058
- !InnerLoop->contains (InnerBB))
1059
+ MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains (InnerBB))
1059
1060
return false ;
1060
1061
1061
1062
// We expect outside uses of `IndPhi' in ExitSucc (and only there).
1062
1063
for (User *U : IndPhi->users ())
1063
- if (!CurLoop->contains (cast<Instruction>(U)))
1064
- if (auto *PN = dyn_cast<PHINode>(U); !PN || PN->getParent () != ExitSucc)
1064
+ if (!CurLoop->contains (cast<Instruction>(U))) {
1065
+ auto *PN = dyn_cast<PHINode>(U);
1066
+ if (!PN || PN->getParent () != ExitSucc)
1065
1067
return false ;
1068
+ }
1066
1069
1067
1070
// Match the loads and check they are simple.
1068
- Value *A, *B;
1069
- if (!match (LoadA, m_Load (m_Value (A))) || !cast<LoadInst>(LoadA)->isSimple () ||
1070
- !match (LoadB, m_Load (m_Value (B))) || !cast<LoadInst>(LoadB)->isSimple ())
1071
+ Value *Search, *Needle;
1072
+ if (!match (LoadSearch, m_Load (m_Value (Search))) ||
1073
+ !match (LoadNeedle, m_Load (m_Value (Needle))) ||
1074
+ !cast<LoadInst>(LoadSearch)->isSimple () ||
1075
+ !cast<LoadInst>(LoadNeedle)->isSimple ())
1071
1076
return false ;
1072
1077
1073
1078
// Check we are loading valid characters.
1074
- Type *CharTy = LoadA ->getType ();
1075
- if (!CharTy->isIntegerTy () || LoadB ->getType () != CharTy)
1079
+ Type *CharTy = LoadSearch ->getType ();
1080
+ if (!CharTy->isIntegerTy () || LoadNeedle ->getType () != CharTy)
1076
1081
return false ;
1077
1082
1078
1083
// Pick the vectorisation factor based on CharTy, work out the cost of the
@@ -1088,40 +1093,40 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
1088
1093
return false ;
1089
1094
1090
1095
// The loads come from two PHIs, each with two incoming values.
1091
- PHINode *PNA = dyn_cast<PHINode>(A );
1092
- PHINode *PNB = dyn_cast<PHINode>(B );
1093
- if (!PNA || PNA ->getNumIncomingValues () != 2 || !PNB ||
1094
- PNB ->getNumIncomingValues () != 2 )
1096
+ PHINode *PSearch = dyn_cast<PHINode>(Search );
1097
+ PHINode *PNeedle = dyn_cast<PHINode>(Needle );
1098
+ if (!PSearch || PSearch ->getNumIncomingValues () != 2 || !PNeedle ||
1099
+ PNeedle ->getNumIncomingValues () != 2 )
1095
1100
return false ;
1096
1101
1097
- // One PHI comes from the outer loop (PNA ), the other one from the inner loop
1098
- // (PNB ). PNA effectively corresponds to IndPhi.
1099
- if (InnerLoop->contains (PNA ))
1100
- std::swap (PNA, PNB );
1101
- if (PNA != &Header->front () || PNB != &MatchBB->front ())
1102
+ // One PHI comes from the outer loop (PSearch ), the other one from the inner
1103
+ // loop (PNeedle ). PSearch effectively corresponds to IndPhi.
1104
+ if (InnerLoop->contains (PSearch ))
1105
+ std::swap (PSearch, PNeedle );
1106
+ if (PSearch != &Header->front () || PNeedle != &MatchBB->front ())
1102
1107
return false ;
1103
1108
1104
1109
// The incoming values of both PHI nodes should be a gep of 1.
1105
- Value *StartA = PNA ->getIncomingValue (0 );
1106
- Value *IndexA = PNA ->getIncomingValue (1 );
1107
- if (CurLoop->contains (PNA ->getIncomingBlock (0 )))
1108
- std::swap (StartA, IndexA );
1110
+ Value *SearchStart = PSearch ->getIncomingValue (0 );
1111
+ Value *SearchIndex = PSearch ->getIncomingValue (1 );
1112
+ if (CurLoop->contains (PSearch ->getIncomingBlock (0 )))
1113
+ std::swap (SearchStart, SearchIndex );
1109
1114
1110
- Value *StartB = PNB ->getIncomingValue (0 );
1111
- Value *IndexB = PNB ->getIncomingValue (1 );
1112
- if (InnerLoop->contains (PNB ->getIncomingBlock (0 )))
1113
- std::swap (StartB, IndexB );
1115
+ Value *NeedleStart = PNeedle ->getIncomingValue (0 );
1116
+ Value *NeedleIndex = PNeedle ->getIncomingValue (1 );
1117
+ if (InnerLoop->contains (PNeedle ->getIncomingBlock (0 )))
1118
+ std::swap (NeedleStart, NeedleIndex );
1114
1119
1115
1120
// Match the GEPs.
1116
- if (!match (IndexA , m_GEP (m_Specific (PNA ), m_One ())) ||
1117
- !match (IndexB , m_GEP (m_Specific (PNB ), m_One ())))
1121
+ if (!match (SearchIndex , m_GEP (m_Specific (PSearch ), m_One ())) ||
1122
+ !match (NeedleIndex , m_GEP (m_Specific (PNeedle ), m_One ())))
1118
1123
return false ;
1119
1124
1120
1125
// Check the GEPs result type matches `CharTy'.
1121
- GetElementPtrInst *GEPA = cast<GetElementPtrInst>(IndexA );
1122
- GetElementPtrInst *GEPB = cast<GetElementPtrInst>(IndexB );
1123
- if (GEPA ->getResultElementType () != CharTy ||
1124
- GEPB ->getResultElementType () != CharTy)
1126
+ GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex );
1127
+ GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex );
1128
+ if (GEPSearch ->getResultElementType () != CharTy ||
1129
+ GEPNeedle ->getResultElementType () != CharTy)
1125
1130
return false ;
1126
1131
1127
1132
// InnerBB should increment the address of the needle pointer.
@@ -1131,11 +1136,12 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
1131
1136
// %18 = icmp eq ptr %17, %10
1132
1137
// br i1 %18, label %OuterBB, label %MatchBB
1133
1138
BasicBlock *OuterBB;
1134
- Value *EndB ;
1139
+ Value *NeedleEnd ;
1135
1140
if (!match (InnerBB->getTerminator (),
1136
- m_Br (m_ICmp (MatchPred, m_Specific (GEPB), m_Value (EndB)),
1141
+ m_Br (m_SpecificICmp (ICmpInst::ICMP_EQ, m_Specific (GEPNeedle),
1142
+ m_Value (NeedleEnd)),
1137
1143
m_BasicBlock (OuterBB), m_Specific (MatchBB))) ||
1138
- MatchPred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains (OuterBB))
1144
+ !CurLoop->contains (OuterBB))
1139
1145
return false ;
1140
1146
1141
1147
// OuterBB should increment the address of the search element pointer.
@@ -1145,17 +1151,17 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
1145
1151
// %25 = icmp eq ptr %24, %6
1146
1152
// br i1 %25, label %ExitFail, label %Header
1147
1153
BasicBlock *ExitFail;
1148
- Value *EndA ;
1154
+ Value *SearchEnd ;
1149
1155
if (!match (OuterBB->getTerminator (),
1150
- m_Br (m_ICmp (MatchPred , m_Specific (GEPA), m_Value (EndA) ),
1151
- m_BasicBlock (ExitFail), m_Specific (Header))) ||
1152
- MatchPred != ICmpInst::Predicate::ICMP_EQ )
1156
+ m_Br (m_SpecificICmp (ICmpInst::ICMP_EQ , m_Specific (GEPSearch ),
1157
+ m_Value (SearchEnd)),
1158
+ m_BasicBlock (ExitFail), m_Specific (Header))) )
1153
1159
return false ;
1154
1160
1155
1161
LLVM_DEBUG (dbgs () << " Found idiom in loop: \n " << *CurLoop << " \n\n " );
1156
1162
1157
- transformFindFirstByte (IndPhi, VF, CharTy, ExitSucc, ExitFail, StartA, EndA ,
1158
- StartB, EndB );
1163
+ transformFindFirstByte (IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart ,
1164
+ SearchEnd, NeedleStart, NeedleEnd );
1159
1165
return true ;
1160
1166
}
1161
1167
@@ -1187,6 +1193,8 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1187
1193
// (I) Inner loop where we iterate over the elements of the needle array.
1188
1194
//
1189
1195
// Overall, the blocks do the following:
1196
+ // (0) Check if the arrays can't cross page boundaries. If so go to (1),
1197
+ // otherwise fall back to the original scalar loop.
1190
1198
// (1) Load the search array. Go to (2).
1191
1199
// (2) (a) Load the needle array.
1192
1200
// (b) Splat the first element to the inactive lanes.
@@ -1196,8 +1204,9 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1196
1204
// (2), otherwise go to (5).
1197
1205
// (5) Check if we've reached the end of the search array. If not loop back to
1198
1206
// (1), otherwise exit.
1199
- // Block (3) is not part of any loop. Blocks (1,5) and (2,4) belong to the
1200
- // outer and inner loops, respectively.
1207
+ // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
1208
+ // the outer and inner loops, respectively.
1209
+ BasicBlock *BB0 = BasicBlock::Create (Ctx, " " , SPH->getParent (), SPH);
1201
1210
BasicBlock *BB1 = BasicBlock::Create (Ctx, " " , SPH->getParent (), SPH);
1202
1211
BasicBlock *BB2 = BasicBlock::Create (Ctx, " " , SPH->getParent (), SPH);
1203
1212
BasicBlock *BB3 = BasicBlock::Create (Ctx, " " , SPH->getParent (), SPH);
@@ -1209,6 +1218,7 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1209
1218
auto InnerLoop = LI->AllocateLoop ();
1210
1219
1211
1220
if (auto ParentLoop = CurLoop->getParentLoop ()) {
1221
+ ParentLoop->addBasicBlockToLoop (BB0, *LI);
1212
1222
ParentLoop->addChildLoop (OuterLoop);
1213
1223
ParentLoop->addBasicBlockToLoop (BB3, *LI);
1214
1224
} else {
@@ -1224,24 +1234,46 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1224
1234
InnerLoop->addBasicBlockToLoop (BB2, *LI);
1225
1235
InnerLoop->addBasicBlockToLoop (BB4, *LI);
1226
1236
1227
- // Set a reference to the old scalar loop and create a predicate of VF
1228
- // elements.
1229
- Builder.SetInsertPoint (Preheader->getTerminator ());
1230
- Value *Pred16 =
1237
+ // Update the terminator added by SplitBlock to branch to the first block.
1238
+ Preheader->getTerminator ()->setSuccessor (0 , BB0);
1239
+ DTU.applyUpdates ({{DominatorTree::Delete, Preheader, SPH},
1240
+ {DominatorTree::Insert, Preheader, BB0}});
1241
+
1242
+ // (0) Check if we could be crossing a page boundary; if so, fallback to the
1243
+ // old scalar loops. Also create a predicate of VF elements to be used in the
1244
+ // vector loops.
1245
+ Builder.SetInsertPoint (BB0);
1246
+ Value *ISearchStart = Builder.CreatePtrToInt (SearchStart, I64Ty);
1247
+ Value *ISearchEnd = Builder.CreatePtrToInt (SearchEnd, I64Ty);
1248
+ Value *INeedleStart = Builder.CreatePtrToInt (NeedleStart, I64Ty);
1249
+ Value *INeedleEnd = Builder.CreatePtrToInt (NeedleEnd, I64Ty);
1250
+ Value *PredVF =
1231
1251
Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1232
1252
{ConstantInt::get (I64Ty, 0 ), ConstVF});
1233
- Builder.CreateCondBr (Builder.getFalse (), SPH, BB1);
1234
- Preheader->getTerminator ()->eraseFromParent ();
1235
- DTU.applyUpdates ({{DominatorTree::Insert, Preheader, BB1}});
1253
+
1254
+ const uint64_t MinPageSize = TTI->getMinPageSize ().value ();
1255
+ const uint64_t AddrShiftAmt = llvm::Log2_64 (MinPageSize);
1256
+ Value *SearchStartPage = Builder.CreateLShr (ISearchStart, AddrShiftAmt);
1257
+ Value *SearchEndPage = Builder.CreateLShr (ISearchEnd, AddrShiftAmt);
1258
+ Value *NeedleStartPage = Builder.CreateLShr (INeedleStart, AddrShiftAmt);
1259
+ Value *NeedleEndPage = Builder.CreateLShr (INeedleEnd, AddrShiftAmt);
1260
+ Value *SearchPageCmp = Builder.CreateICmpNE (SearchStartPage, SearchEndPage);
1261
+ Value *NeedlePageCmp = Builder.CreateICmpNE (NeedleStartPage, NeedleEndPage);
1262
+
1263
+ Value *CombinedPageCmp = Builder.CreateOr (SearchPageCmp, NeedlePageCmp);
1264
+ BranchInst *CombinedPageBr = Builder.CreateCondBr (CombinedPageCmp, SPH, BB1);
1265
+ CombinedPageBr->setMetadata (LLVMContext::MD_prof,
1266
+ MDBuilder (Ctx).createBranchWeights (10 , 90 ));
1267
+ DTU.applyUpdates (
1268
+ {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
1236
1269
1237
1270
// (1) Load the search array and branch to the inner loop.
1238
1271
Builder.SetInsertPoint (BB1);
1239
1272
PHINode *Search = Builder.CreatePHI (PtrTy, 2 , " psearch" );
1240
- Value *PredSearch =
1241
- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1242
- {Builder.CreatePointerCast (Search, I64Ty),
1243
- Builder.CreatePointerCast (SearchEnd, I64Ty)});
1244
- PredSearch = Builder.CreateAnd (Pred16, PredSearch);
1273
+ Value *PredSearch = Builder.CreateIntrinsic (
1274
+ Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1275
+ {Builder.CreatePtrToInt (Search, I64Ty), ISearchEnd});
1276
+ PredSearch = Builder.CreateAnd (PredVF, PredSearch);
1245
1277
Value *LoadSearch =
1246
1278
Builder.CreateMaskedLoad (CharVTy, Search, Align (1 ), PredSearch, Passthru);
1247
1279
Builder.CreateBr (BB2);
@@ -1252,11 +1284,10 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1252
1284
PHINode *Needle = Builder.CreatePHI (PtrTy, 2 , " pneedle" );
1253
1285
1254
1286
// (2.a) Load the needle array.
1255
- Value *PredNeedle =
1256
- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1257
- {Builder.CreatePointerCast (Needle, I64Ty),
1258
- Builder.CreatePointerCast (NeedleEnd, I64Ty)});
1259
- PredNeedle = Builder.CreateAnd (Pred16, PredNeedle);
1287
+ Value *PredNeedle = Builder.CreateIntrinsic (
1288
+ Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1289
+ {Builder.CreatePtrToInt (Needle, I64Ty), INeedleEnd});
1290
+ PredNeedle = Builder.CreateAnd (PredVF, PredNeedle);
1260
1291
Value *LoadNeedle =
1261
1292
Builder.CreateMaskedLoad (CharVTy, Needle, Align (1 ), PredNeedle, Passthru);
1262
1293
@@ -1279,10 +1310,12 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1279
1310
1280
1311
// (3) We found a match. Compute the index of its location and exit.
1281
1312
Builder.SetInsertPoint (BB3);
1313
+ PHINode *MatchLCSSA = Builder.CreatePHI (PtrTy, 1 );
1314
+ PHINode *MatchPredLCSSA = Builder.CreatePHI (MatchPred->getType (), 1 );
1282
1315
Value *MatchCnt = Builder.CreateIntrinsic (
1283
1316
Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType ()},
1284
- {MatchPred , /* ZeroIsPoison=*/ Builder.getInt1 (true )});
1285
- Value *MatchVal = Builder.CreateGEP (CharTy, Search , MatchCnt);
1317
+ {MatchPredLCSSA , /* ZeroIsPoison=*/ Builder.getInt1 (true )});
1318
+ Value *MatchVal = Builder.CreateGEP (CharTy, MatchLCSSA , MatchCnt);
1286
1319
Builder.CreateBr (ExitSucc);
1287
1320
DTU.applyUpdates ({{DominatorTree::Insert, BB3, ExitSucc}});
1288
1321
@@ -1301,11 +1334,14 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
1301
1334
DTU.applyUpdates ({{DominatorTree::Insert, BB5, BB1},
1302
1335
{DominatorTree::Insert, BB5, ExitFail}});
1303
1336
1304
- // Set up the PHI's .
1305
- Search->addIncoming (SearchStart, Preheader );
1337
+ // Set up the PHI nodes .
1338
+ Search->addIncoming (SearchStart, BB0 );
1306
1339
Search->addIncoming (NextSearch, BB5);
1307
1340
Needle->addIncoming (NeedleStart, BB1);
1308
1341
Needle->addIncoming (NextNeedle, BB4);
1342
+ // These are needed to retain LCSSA form.
1343
+ MatchLCSSA->addIncoming (Search, BB2);
1344
+ MatchPredLCSSA->addIncoming (MatchPred, BB2);
1309
1345
1310
1346
if (VerifyLoops) {
1311
1347
OuterLoop->verifyLoop ();
@@ -1332,11 +1368,16 @@ void LoopIdiomVectorize::transformFindFirstByte(
1332
1368
expandFindFirstByte (Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
1333
1369
SearchStart, SearchEnd, NeedleStart, NeedleEnd);
1334
1370
1371
+ assert (PHBranch->isUnconditional () &&
1372
+ " Expected preheader to terminate with an unconditional branch." );
1373
+
1335
1374
// Add new incoming values with the result of the transformation to PHINodes
1336
1375
// of ExitSucc that use IndPhi.
1337
- for (auto *U : llvm::make_early_inc_range (IndPhi->users ()))
1338
- if (auto *PN = dyn_cast<PHINode>(U); PN && PN->getParent () == ExitSucc)
1376
+ for (auto *U : llvm::make_early_inc_range (IndPhi->users ())) {
1377
+ auto *PN = dyn_cast<PHINode>(U);
1378
+ if (PN && PN->getParent () == ExitSucc)
1339
1379
PN->addIncoming (MatchVal, cast<Instruction>(MatchVal)->getParent ());
1380
+ }
1340
1381
1341
1382
if (VerifyLoops && CurLoop->getParentLoop ()) {
1342
1383
CurLoop->getParentLoop ()->verifyLoop ();
0 commit comments