Skip to content

Commit d5769df

Browse files
committed
[AggressiveInstCombine] Inline strcmp/strncmp
* isOnlyUsedInZeroComparison * more tests * ...
1 parent e2d3521 commit d5769df

File tree

6 files changed

+313
-319
lines changed

6 files changed

+313
-319
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ bool isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
116116
const DominatorTree *DT = nullptr,
117117
bool UseInstrInfo = true);
118118

119+
bool isOnlyUsedInZeroComparison(const Instruction *CxtI);
120+
119121
bool isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI);
120122

121123
/// Return true if the given value is known to be non-zero when defined. For

llvm/include/llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
/// \file
99
///
1010
/// AggressiveInstCombiner - Combine expression patterns to form expressions
11-
/// with fewer, simple instructions. This pass does not modify the CFG.
11+
/// with fewer, simple instructions.
1212
///
1313
//===----------------------------------------------------------------------===//
1414

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
252252
RHSCache.getKnownBits(SQ));
253253
}
254254

255+
bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
256+
return !I->user_empty() && all_of(I->users(), [](const User *U) {
257+
ICmpInst::Predicate P;
258+
return match(U, m_ICmp(P, m_Value(), m_Zero()));
259+
});
260+
}
261+
255262
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
256263
return !I->user_empty() && all_of(I->users(), [](const User *U) {
257264
ICmpInst::Predicate P;

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 67 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
7575
m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
7676
m_LShr(m_Value(ShVal1),
7777
m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
78-
return Intrinsic::fshl;
78+
return Intrinsic::fshl;
7979
}
8080

8181
// fshr(ShVal0, ShVal1, ShAmt)
@@ -84,7 +84,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
8484
m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
8585
m_Value(ShAmt))),
8686
m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
87-
return Intrinsic::fshr;
87+
return Intrinsic::fshr;
8888
}
8989

9090
return Intrinsic::not_intrinsic;
@@ -401,21 +401,11 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
401401
/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
402402
/// pessimistic codegen that has to account for setting errno and can enable
403403
/// vectorization.
404-
static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
404+
static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
405405
TargetLibraryInfo &TLI, AssumptionCache &AC,
406406
DominatorTree &DT) {
407-
// Match a call to sqrt mathlib function.
408-
auto *Call = dyn_cast<CallInst>(&I);
409-
if (!Call)
410-
return false;
411407

412408
Module *M = Call->getModule();
413-
LibFunc Func;
414-
if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
415-
return false;
416-
417-
if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
418-
return false;
419409

420410
// If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
421411
// (because NNAN or the operand arg must not be less than -0.0) and (2) we
@@ -428,18 +418,18 @@ static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
428418
if (TTI.haveFastSqrt(Ty) &&
429419
(Call->hasNoNaNs() ||
430420
cannotBeOrderedLessThanZero(
431-
Arg, 0, SimplifyQuery(M->getDataLayout(), &TLI, &DT, &AC, &I)))) {
432-
IRBuilder<> Builder(&I);
421+
Arg, 0, SimplifyQuery(M->getDataLayout(), &TLI, &DT, &AC, Call)))) {
422+
IRBuilder<> Builder(Call);
433423
IRBuilderBase::FastMathFlagGuard Guard(Builder);
434424
Builder.setFastMathFlags(Call->getFastMathFlags());
435425

436426
Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
437427
Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
438-
I.replaceAllUsesWith(NewSqrt);
428+
Call->replaceAllUsesWith(NewSqrt);
439429

440430
// Explicitly erase the old call because a call with side effects is not
441431
// trivially dead.
442-
I.eraseFromParent();
432+
Call->eraseFromParent();
443433
return true;
444434
}
445435

@@ -932,18 +922,17 @@ static cl::opt<unsigned> StrNCmpInlineThreshold(
932922
namespace {
933923
class StrNCmpInliner {
934924
public:
935-
StrNCmpInliner(CallInst *CI, LibFunc Func, Function::iterator &BBNext,
936-
DomTreeUpdater *DTU, const DataLayout &DL)
937-
: CI(CI), Func(Func), BBNext(BBNext), DTU(DTU), DL(DL) {}
925+
StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
926+
const DataLayout &DL)
927+
: CI(CI), Func(Func), DTU(DTU), DL(DL) {}
938928

939929
bool optimizeStrNCmp();
940930

941931
private:
942-
bool inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Switched);
932+
bool inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);
943933

944934
CallInst *CI;
945935
LibFunc Func;
946-
Function::iterator &BBNext;
947936
DomTreeUpdater *DTU;
948937
const DataLayout &DL;
949938
};
@@ -952,7 +941,7 @@ class StrNCmpInliner {
952941

953942
/// First we normalize calls to strncmp/strcmp to the form of
954943
/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
955-
/// (without considering '\0')
944+
/// (without considering '\0').
956945
///
957946
/// Examples:
958947
///
@@ -969,49 +958,53 @@ class StrNCmpInliner {
969958
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
970959
/// \endcode
971960
///
972-
/// We only handle cases that N and exactly one of s1 and s2 are constant. Cases
973-
/// that s1 and s2 are both constant are already handled by the instcombine
974-
/// pass.
961+
/// We only handle cases where N and exactly one of s1 and s2 are constant.
962+
/// Cases that s1 and s2 are both constant are already handled by the
963+
/// instcombine pass.
975964
///
976-
/// We do not handle cases that N > StrNCmpInlineThreshold.
965+
/// We do not handle cases where N > StrNCmpInlineThreshold.
977966
///
978-
/// We also do not handles cases that N < 2, which are already
967+
/// We also do not handles cases where N < 2, which are already
979968
/// handled by the instcombine pass.
980969
///
981970
bool StrNCmpInliner::optimizeStrNCmp() {
982971
if (StrNCmpInlineThreshold < 2)
983972
return false;
984973

974+
if (!isOnlyUsedInZeroComparison(CI))
975+
return false;
976+
985977
Value *Str1P = CI->getArgOperand(0);
986978
Value *Str2P = CI->getArgOperand(1);
987-
// should be handled elsewhere
979+
// Should be handled elsewhere.
988980
if (Str1P == Str2P)
989981
return false;
990982

991983
StringRef Str1, Str2;
992984
bool HasStr1 = getConstantStringInfo(Str1P, Str1, false);
993985
bool HasStr2 = getConstantStringInfo(Str2P, Str2, false);
994-
if (!(HasStr1 ^ HasStr2))
986+
if (HasStr1 == HasStr2)
995987
return false;
996988

997-
// note that '\0' and characters after it are not trimmed
989+
// Note that '\0' and characters after it are not trimmed.
998990
StringRef Str = HasStr1 ? Str1 : Str2;
999991

1000992
size_t Idx = Str.find('\0');
1001993
uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1002994
if (Func == LibFunc_strncmp) {
1003-
if (!isa<ConstantInt>(CI->getArgOperand(2)))
995+
if (auto ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
996+
N = std::min(N, ConstInt->getZExtValue());
997+
else
1004998
return false;
1005-
N = std::min(N, cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue());
1006999
}
1007-
// now N means how many bytes we need to compare at most
1000+
// Now N means how many bytes we need to compare at most.
10081001
if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
10091002
return false;
10101003

10111004
Value *StrP = HasStr1 ? Str2P : Str1P;
10121005

1013-
// cases that StrP has two or more dereferenceable bytes might be better
1014-
// optimized elsewhere
1006+
// Cases where StrP has two or more dereferenceable bytes might be better
1007+
// optimized elsewhere.
10151008
bool CanBeNull = false, CanBeFreed = false;
10161009
if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
10171010
return false;
@@ -1054,7 +1047,7 @@ bool StrNCmpInliner::optimizeStrNCmp() {
10541047
/// BBSubs[N-1] (sub) ---------+
10551048
///
10561049
bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
1057-
bool Switched) {
1050+
bool Swapped) {
10581051
auto &Ctx = CI->getContext();
10591052
IRBuilder<> B(Ctx);
10601053

@@ -1076,12 +1069,12 @@ bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
10761069
Value *Base = LHS;
10771070
for (uint64_t i = 0; i < N; ++i) {
10781071
B.SetInsertPoint(BBSubs[i]);
1079-
Value *VL = B.CreateZExt(
1080-
B.CreateLoad(B.getInt8Ty(),
1081-
B.CreateInBoundsGEP(B.getInt8Ty(), Base, B.getInt64(i))),
1082-
CI->getType());
1072+
Value *VL =
1073+
B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
1074+
B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
1075+
CI->getType());
10831076
Value *VR = ConstantInt::get(CI->getType(), RHS[i]);
1084-
Value *Sub = Switched ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
1077+
Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
10851078
if (i < N - 1)
10861079
B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
10871080
BBNE, BBSubs[i + 1]);
@@ -1094,67 +1087,56 @@ bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
10941087
CI->replaceAllUsesWith(Phi);
10951088
CI->eraseFromParent();
10961089

1097-
BBNext = BBCI->getIterator();
1098-
1099-
// Update DomTree
11001090
if (DTU) {
11011091
SmallVector<DominatorTree::UpdateType, 8> Updates;
1102-
Updates.push_back({DominatorTree::Delete, BBBefore, BBCI});
11031092
Updates.push_back({DominatorTree::Insert, BBBefore, BBSubs[0]});
11041093
for (uint64_t i = 0; i < N; ++i) {
11051094
if (i < N - 1)
11061095
Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
11071096
Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
11081097
}
11091098
Updates.push_back({DominatorTree::Insert, BBNE, BBCI});
1099+
Updates.push_back({DominatorTree::Delete, BBBefore, BBCI});
11101100
DTU->applyUpdates(Updates);
11111101
}
11121102
return true;
11131103
}
11141104

1115-
static bool inlineLibCalls(Function &F, TargetLibraryInfo &TLI,
1116-
const TargetTransformInfo &TTI, DominatorTree &DT,
1117-
const DataLayout &DL, bool &MadeCFGChange) {
1118-
MadeCFGChange = false;
1119-
DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1120-
1121-
bool MadeChange = false;
1122-
1123-
Function::iterator CurrBB;
1124-
for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) {
1125-
CurrBB = BB++;
1105+
static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1106+
TargetLibraryInfo &TLI, llvm::AssumptionCache &AC,
1107+
DominatorTree &DT, const DataLayout &DL,
1108+
bool &MadeCFGChange) {
11261109

1127-
for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end();
1128-
II != IE; ++II) {
1129-
CallInst *Call = dyn_cast<CallInst>(&*II);
1130-
Function *CalledFunc;
1110+
auto *CI = dyn_cast<CallInst>(&I);
1111+
if (!CI || CI->isNoBuiltin())
1112+
return false;
11311113

1132-
if (!Call || !(CalledFunc = Call->getCalledFunction()))
1133-
continue;
1114+
Function *CalledFunc = CI->getCalledFunction();
1115+
if (!CalledFunc)
1116+
return false;
11341117

1135-
LibFunc LF;
1136-
if (!TLI.getLibFunc(*CalledFunc, LF))
1137-
continue;
1118+
LibFunc LF;
1119+
if (!TLI.getLibFunc(*CalledFunc, LF) ||
1120+
!isLibFuncEmittable(CI->getModule(), &TLI, LF))
1121+
return false;
11381122

1139-
switch (LF) {
1140-
case LibFunc_strcmp:
1141-
case LibFunc_strncmp: {
1142-
if (StrNCmpInliner(Call, LF, BB, &DTU, DL).optimizeStrNCmp()) {
1143-
MadeCFGChange = true;
1144-
break;
1145-
}
1146-
continue;
1147-
}
1148-
default:
1149-
continue;
1150-
}
1123+
DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
11511124

1152-
MadeChange = true;
1153-
break;
1125+
switch (LF) {
1126+
case LibFunc_sqrt:
1127+
case LibFunc_sqrtf:
1128+
case LibFunc_sqrtl:
1129+
return foldSqrt(CI, LF, TTI, TLI, AC, DT);
1130+
case LibFunc_strcmp:
1131+
case LibFunc_strncmp:
1132+
if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1133+
MadeCFGChange = true;
1134+
return true;
11541135
}
1136+
break;
1137+
default:;
11551138
}
1156-
1157-
return MadeChange;
1139+
return false;
11581140
}
11591141

11601142
/// This is the entry point for folds that could be implemented in regular
@@ -1163,7 +1145,7 @@ static bool inlineLibCalls(Function &F, TargetLibraryInfo &TLI,
11631145
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
11641146
TargetTransformInfo &TTI,
11651147
TargetLibraryInfo &TLI, AliasAnalysis &AA,
1166-
AssumptionCache &AC) {
1148+
AssumptionCache &AC, bool &MadeCFGChange) {
11671149
bool MadeChange = false;
11681150
for (BasicBlock &BB : F) {
11691151
// Ignore unreachable basic blocks.
@@ -1188,7 +1170,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
11881170
// NOTE: This function introduces erasing of the instruction `I`, so it
11891171
// needs to be called at the end of this sequence, otherwise we may make
11901172
// bugs.
1191-
MadeChange |= foldSqrt(I, TTI, TLI, AC, DT);
1173+
MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
11921174
}
11931175
}
11941176

@@ -1209,8 +1191,7 @@ static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
12091191
const DataLayout &DL = F.getParent()->getDataLayout();
12101192
TruncInstCombine TIC(AC, TLI, DL, DT);
12111193
MadeChange |= TIC.run(F);
1212-
MadeChange |= inlineLibCalls(F, TLI, TTI, DT, DL, MadeCFGChange);
1213-
MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC);
1194+
MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
12141195
return MadeChange;
12151196
}
12161197

0 commit comments

Comments
 (0)