Skip to content

Commit ca1a373

Browse files
committed
Add page boundary checks and address other comments
1 parent d7b6b9d commit ca1a373

File tree

2 files changed

+261
-191
lines changed

2 files changed

+261
-191
lines changed

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 112 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -992,8 +992,10 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
992992
bool LoopIdiomVectorize::recognizeFindFirstByte() {
993993
// Currently the transformation only works on scalable vector types, although
994994
// there is no fundamental reason why it cannot be made to work for fixed
995-
// vectors too.
996-
if (!TTI->supportsScalableVectors() || DisableFindFirstByte)
995+
// vectors. We also need to know the target's minimum page size in order to
996+
// generate runtime memory checks to ensure the vector version won't fault.
997+
if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
998+
DisableFindFirstByte)
997999
return false;
9981000

9991001
// Define some constants we need throughout.
@@ -1049,30 +1051,33 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
10491051
// %22 = icmp eq i8 %15, %21
10501052
// br i1 %22, label %ExitSucc, label %InnerBB
10511053
BasicBlock *ExitSucc, *InnerBB;
1052-
Value *LoadA, *LoadB;
1053-
ICmpInst::Predicate MatchPred;
1054+
Value *LoadSearch, *LoadNeedle;
1055+
CmpPredicate MatchPred;
10541056
if (!match(MatchBB->getTerminator(),
1055-
m_Br(m_ICmp(MatchPred, m_Value(LoadA), m_Value(LoadB)),
1057+
m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
10561058
m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
1057-
MatchPred != ICmpInst::Predicate::ICMP_EQ ||
1058-
!InnerLoop->contains(InnerBB))
1059+
MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
10591060
return false;
10601061

10611062
// We expect outside uses of `IndPhi' in ExitSucc (and only there).
10621063
for (User *U : IndPhi->users())
1063-
if (!CurLoop->contains(cast<Instruction>(U)))
1064-
if (auto *PN = dyn_cast<PHINode>(U); !PN || PN->getParent() != ExitSucc)
1064+
if (!CurLoop->contains(cast<Instruction>(U))) {
1065+
auto *PN = dyn_cast<PHINode>(U);
1066+
if (!PN || PN->getParent() != ExitSucc)
10651067
return false;
1068+
}
10661069

10671070
// Match the loads and check they are simple.
1068-
Value *A, *B;
1069-
if (!match(LoadA, m_Load(m_Value(A))) || !cast<LoadInst>(LoadA)->isSimple() ||
1070-
!match(LoadB, m_Load(m_Value(B))) || !cast<LoadInst>(LoadB)->isSimple())
1071+
Value *Search, *Needle;
1072+
if (!match(LoadSearch, m_Load(m_Value(Search))) ||
1073+
!match(LoadNeedle, m_Load(m_Value(Needle))) ||
1074+
!cast<LoadInst>(LoadSearch)->isSimple() ||
1075+
!cast<LoadInst>(LoadNeedle)->isSimple())
10711076
return false;
10721077

10731078
// Check we are loading valid characters.
1074-
Type *CharTy = LoadA->getType();
1075-
if (!CharTy->isIntegerTy() || LoadB->getType() != CharTy)
1079+
Type *CharTy = LoadSearch->getType();
1080+
if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
10761081
return false;
10771082

10781083
// Pick the vectorisation factor based on CharTy, work out the cost of the
@@ -1088,40 +1093,40 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
10881093
return false;
10891094

10901095
// The loads come from two PHIs, each with two incoming values.
1091-
PHINode *PNA = dyn_cast<PHINode>(A);
1092-
PHINode *PNB = dyn_cast<PHINode>(B);
1093-
if (!PNA || PNA->getNumIncomingValues() != 2 || !PNB ||
1094-
PNB->getNumIncomingValues() != 2)
1096+
PHINode *PSearch = dyn_cast<PHINode>(Search);
1097+
PHINode *PNeedle = dyn_cast<PHINode>(Needle);
1098+
if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
1099+
PNeedle->getNumIncomingValues() != 2)
10951100
return false;
10961101

1097-
// One PHI comes from the outer loop (PNA), the other one from the inner loop
1098-
// (PNB). PNA effectively corresponds to IndPhi.
1099-
if (InnerLoop->contains(PNA))
1100-
std::swap(PNA, PNB);
1101-
if (PNA != &Header->front() || PNB != &MatchBB->front())
1102+
// One PHI comes from the outer loop (PSearch), the other one from the inner
1103+
// loop (PNeedle). PSearch effectively corresponds to IndPhi.
1104+
if (InnerLoop->contains(PSearch))
1105+
std::swap(PSearch, PNeedle);
1106+
if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
11021107
return false;
11031108

11041109
// The incoming values of both PHI nodes should be a gep of 1.
1105-
Value *StartA = PNA->getIncomingValue(0);
1106-
Value *IndexA = PNA->getIncomingValue(1);
1107-
if (CurLoop->contains(PNA->getIncomingBlock(0)))
1108-
std::swap(StartA, IndexA);
1110+
Value *SearchStart = PSearch->getIncomingValue(0);
1111+
Value *SearchIndex = PSearch->getIncomingValue(1);
1112+
if (CurLoop->contains(PSearch->getIncomingBlock(0)))
1113+
std::swap(SearchStart, SearchIndex);
11091114

1110-
Value *StartB = PNB->getIncomingValue(0);
1111-
Value *IndexB = PNB->getIncomingValue(1);
1112-
if (InnerLoop->contains(PNB->getIncomingBlock(0)))
1113-
std::swap(StartB, IndexB);
1115+
Value *NeedleStart = PNeedle->getIncomingValue(0);
1116+
Value *NeedleIndex = PNeedle->getIncomingValue(1);
1117+
if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
1118+
std::swap(NeedleStart, NeedleIndex);
11141119

11151120
// Match the GEPs.
1116-
if (!match(IndexA, m_GEP(m_Specific(PNA), m_One())) ||
1117-
!match(IndexB, m_GEP(m_Specific(PNB), m_One())))
1121+
if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
1122+
!match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
11181123
return false;
11191124

11201125
// Check the GEPs result type matches `CharTy'.
1121-
GetElementPtrInst *GEPA = cast<GetElementPtrInst>(IndexA);
1122-
GetElementPtrInst *GEPB = cast<GetElementPtrInst>(IndexB);
1123-
if (GEPA->getResultElementType() != CharTy ||
1124-
GEPB->getResultElementType() != CharTy)
1126+
GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
1127+
GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
1128+
if (GEPSearch->getResultElementType() != CharTy ||
1129+
GEPNeedle->getResultElementType() != CharTy)
11251130
return false;
11261131

11271132
// InnerBB should increment the address of the needle pointer.
@@ -1131,11 +1136,12 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
11311136
// %18 = icmp eq ptr %17, %10
11321137
// br i1 %18, label %OuterBB, label %MatchBB
11331138
BasicBlock *OuterBB;
1134-
Value *EndB;
1139+
Value *NeedleEnd;
11351140
if (!match(InnerBB->getTerminator(),
1136-
m_Br(m_ICmp(MatchPred, m_Specific(GEPB), m_Value(EndB)),
1141+
m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
1142+
m_Value(NeedleEnd)),
11371143
m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
1138-
MatchPred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(OuterBB))
1144+
!CurLoop->contains(OuterBB))
11391145
return false;
11401146

11411147
// OuterBB should increment the address of the search element pointer.
@@ -1145,17 +1151,17 @@ bool LoopIdiomVectorize::recognizeFindFirstByte() {
11451151
// %25 = icmp eq ptr %24, %6
11461152
// br i1 %25, label %ExitFail, label %Header
11471153
BasicBlock *ExitFail;
1148-
Value *EndA;
1154+
Value *SearchEnd;
11491155
if (!match(OuterBB->getTerminator(),
1150-
m_Br(m_ICmp(MatchPred, m_Specific(GEPA), m_Value(EndA)),
1151-
m_BasicBlock(ExitFail), m_Specific(Header))) ||
1152-
MatchPred != ICmpInst::Predicate::ICMP_EQ)
1156+
m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
1157+
m_Value(SearchEnd)),
1158+
m_BasicBlock(ExitFail), m_Specific(Header))))
11531159
return false;
11541160

11551161
LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
11561162

1157-
transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, StartA, EndA,
1158-
StartB, EndB);
1163+
transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
1164+
SearchEnd, NeedleStart, NeedleEnd);
11591165
return true;
11601166
}
11611167

@@ -1187,6 +1193,8 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
11871193
// (I) Inner loop where we iterate over the elements of the needle array.
11881194
//
11891195
// Overall, the blocks do the following:
1196+
// (0) Check if the arrays can't cross page boundaries. If so go to (1),
1197+
// otherwise fall back to the original scalar loop.
11901198
// (1) Load the search array. Go to (2).
11911199
// (2) (a) Load the needle array.
11921200
// (b) Splat the first element to the inactive lanes.
@@ -1196,8 +1204,9 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
11961204
// (2), otherwise go to (5).
11971205
// (5) Check if we've reached the end of the search array. If not loop back to
11981206
// (1), otherwise exit.
1199-
// Block (3) is not part of any loop. Blocks (1,5) and (2,4) belong to the
1200-
// outer and inner loops, respectively.
1207+
// Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
1208+
// the outer and inner loops, respectively.
1209+
BasicBlock *BB0 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
12011210
BasicBlock *BB1 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
12021211
BasicBlock *BB2 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
12031212
BasicBlock *BB3 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
@@ -1209,6 +1218,7 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
12091218
auto InnerLoop = LI->AllocateLoop();
12101219

12111220
if (auto ParentLoop = CurLoop->getParentLoop()) {
1221+
ParentLoop->addBasicBlockToLoop(BB0, *LI);
12121222
ParentLoop->addChildLoop(OuterLoop);
12131223
ParentLoop->addBasicBlockToLoop(BB3, *LI);
12141224
} else {
@@ -1224,24 +1234,46 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
12241234
InnerLoop->addBasicBlockToLoop(BB2, *LI);
12251235
InnerLoop->addBasicBlockToLoop(BB4, *LI);
12261236

1227-
// Set a reference to the old scalar loop and create a predicate of VF
1228-
// elements.
1229-
Builder.SetInsertPoint(Preheader->getTerminator());
1230-
Value *Pred16 =
1237+
// Update the terminator added by SplitBlock to branch to the first block.
1238+
Preheader->getTerminator()->setSuccessor(0, BB0);
1239+
DTU.applyUpdates({{DominatorTree::Delete, Preheader, SPH},
1240+
{DominatorTree::Insert, Preheader, BB0}});
1241+
1242+
// (0) Check if we could be crossing a page boundary; if so, fallback to the
1243+
// old scalar loops. Also create a predicate of VF elements to be used in the
1244+
// vector loops.
1245+
Builder.SetInsertPoint(BB0);
1246+
Value *ISearchStart = Builder.CreatePtrToInt(SearchStart, I64Ty);
1247+
Value *ISearchEnd = Builder.CreatePtrToInt(SearchEnd, I64Ty);
1248+
Value *INeedleStart = Builder.CreatePtrToInt(NeedleStart, I64Ty);
1249+
Value *INeedleEnd = Builder.CreatePtrToInt(NeedleEnd, I64Ty);
1250+
Value *PredVF =
12311251
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
12321252
{ConstantInt::get(I64Ty, 0), ConstVF});
1233-
Builder.CreateCondBr(Builder.getFalse(), SPH, BB1);
1234-
Preheader->getTerminator()->eraseFromParent();
1235-
DTU.applyUpdates({{DominatorTree::Insert, Preheader, BB1}});
1253+
1254+
const uint64_t MinPageSize = TTI->getMinPageSize().value();
1255+
const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
1256+
Value *SearchStartPage = Builder.CreateLShr(ISearchStart, AddrShiftAmt);
1257+
Value *SearchEndPage = Builder.CreateLShr(ISearchEnd, AddrShiftAmt);
1258+
Value *NeedleStartPage = Builder.CreateLShr(INeedleStart, AddrShiftAmt);
1259+
Value *NeedleEndPage = Builder.CreateLShr(INeedleEnd, AddrShiftAmt);
1260+
Value *SearchPageCmp = Builder.CreateICmpNE(SearchStartPage, SearchEndPage);
1261+
Value *NeedlePageCmp = Builder.CreateICmpNE(NeedleStartPage, NeedleEndPage);
1262+
1263+
Value *CombinedPageCmp = Builder.CreateOr(SearchPageCmp, NeedlePageCmp);
1264+
BranchInst *CombinedPageBr = Builder.CreateCondBr(CombinedPageCmp, SPH, BB1);
1265+
CombinedPageBr->setMetadata(LLVMContext::MD_prof,
1266+
MDBuilder(Ctx).createBranchWeights(10, 90));
1267+
DTU.applyUpdates(
1268+
{{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
12361269

12371270
// (1) Load the search array and branch to the inner loop.
12381271
Builder.SetInsertPoint(BB1);
12391272
PHINode *Search = Builder.CreatePHI(PtrTy, 2, "psearch");
1240-
Value *PredSearch =
1241-
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1242-
{Builder.CreatePointerCast(Search, I64Ty),
1243-
Builder.CreatePointerCast(SearchEnd, I64Ty)});
1244-
PredSearch = Builder.CreateAnd(Pred16, PredSearch);
1273+
Value *PredSearch = Builder.CreateIntrinsic(
1274+
Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1275+
{Builder.CreatePtrToInt(Search, I64Ty), ISearchEnd});
1276+
PredSearch = Builder.CreateAnd(PredVF, PredSearch);
12451277
Value *LoadSearch =
12461278
Builder.CreateMaskedLoad(CharVTy, Search, Align(1), PredSearch, Passthru);
12471279
Builder.CreateBr(BB2);
@@ -1252,11 +1284,10 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
12521284
PHINode *Needle = Builder.CreatePHI(PtrTy, 2, "pneedle");
12531285

12541286
// (2.a) Load the needle array.
1255-
Value *PredNeedle =
1256-
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1257-
{Builder.CreatePointerCast(Needle, I64Ty),
1258-
Builder.CreatePointerCast(NeedleEnd, I64Ty)});
1259-
PredNeedle = Builder.CreateAnd(Pred16, PredNeedle);
1287+
Value *PredNeedle = Builder.CreateIntrinsic(
1288+
Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1289+
{Builder.CreatePtrToInt(Needle, I64Ty), INeedleEnd});
1290+
PredNeedle = Builder.CreateAnd(PredVF, PredNeedle);
12601291
Value *LoadNeedle =
12611292
Builder.CreateMaskedLoad(CharVTy, Needle, Align(1), PredNeedle, Passthru);
12621293

@@ -1279,10 +1310,12 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
12791310

12801311
// (3) We found a match. Compute the index of its location and exit.
12811312
Builder.SetInsertPoint(BB3);
1313+
PHINode *MatchLCSSA = Builder.CreatePHI(PtrTy, 1);
1314+
PHINode *MatchPredLCSSA = Builder.CreatePHI(MatchPred->getType(), 1);
12821315
Value *MatchCnt = Builder.CreateIntrinsic(
12831316
Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
1284-
{MatchPred, /*ZeroIsPoison=*/Builder.getInt1(true)});
1285-
Value *MatchVal = Builder.CreateGEP(CharTy, Search, MatchCnt);
1317+
{MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(true)});
1318+
Value *MatchVal = Builder.CreateGEP(CharTy, MatchLCSSA, MatchCnt);
12861319
Builder.CreateBr(ExitSucc);
12871320
DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});
12881321

@@ -1301,11 +1334,14 @@ Value *LoopIdiomVectorize::expandFindFirstByte(
13011334
DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
13021335
{DominatorTree::Insert, BB5, ExitFail}});
13031336

1304-
// Set up the PHI's.
1305-
Search->addIncoming(SearchStart, Preheader);
1337+
// Set up the PHI nodes.
1338+
Search->addIncoming(SearchStart, BB0);
13061339
Search->addIncoming(NextSearch, BB5);
13071340
Needle->addIncoming(NeedleStart, BB1);
13081341
Needle->addIncoming(NextNeedle, BB4);
1342+
// These are needed to retain LCSSA form.
1343+
MatchLCSSA->addIncoming(Search, BB2);
1344+
MatchPredLCSSA->addIncoming(MatchPred, BB2);
13091345

13101346
if (VerifyLoops) {
13111347
OuterLoop->verifyLoop();
@@ -1332,11 +1368,16 @@ void LoopIdiomVectorize::transformFindFirstByte(
13321368
expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
13331369
SearchStart, SearchEnd, NeedleStart, NeedleEnd);
13341370

1371+
assert(PHBranch->isUnconditional() &&
1372+
"Expected preheader to terminate with an unconditional branch.");
1373+
13351374
// Add new incoming values with the result of the transformation to PHINodes
13361375
// of ExitSucc that use IndPhi.
1337-
for (auto *U : llvm::make_early_inc_range(IndPhi->users()))
1338-
if (auto *PN = dyn_cast<PHINode>(U); PN && PN->getParent() == ExitSucc)
1376+
for (auto *U : llvm::make_early_inc_range(IndPhi->users())) {
1377+
auto *PN = dyn_cast<PHINode>(U);
1378+
if (PN && PN->getParent() == ExitSucc)
13391379
PN->addIncoming(MatchVal, cast<Instruction>(MatchVal)->getParent());
1380+
}
13401381

13411382
if (VerifyLoops && CurLoop->getParentLoop()) {
13421383
CurLoop->getParentLoop()->verifyLoop();

0 commit comments

Comments
 (0)