Skip to content

Commit 9726bd4

Browse files
committed
[AggressiveInstCombine] Inline strcmp/strncmp
Inline calls to strcmp(s1, s2) and strncmp(s1, s2, N), where N and exactly one of s1 and s2 are constant.
1 parent 2a4e61b commit 9726bd4

File tree

4 files changed

+603
-222
lines changed

4 files changed

+603
-222
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 255 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/Analysis/AssumptionCache.h"
2020
#include "llvm/Analysis/BasicAliasAnalysis.h"
2121
#include "llvm/Analysis/ConstantFolding.h"
22+
#include "llvm/Analysis/DomTreeUpdater.h"
2223
#include "llvm/Analysis/GlobalsModRef.h"
2324
#include "llvm/Analysis/TargetLibraryInfo.h"
2425
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -28,6 +29,7 @@
2829
#include "llvm/IR/Function.h"
2930
#include "llvm/IR/IRBuilder.h"
3031
#include "llvm/IR/PatternMatch.h"
32+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
3133
#include "llvm/Transforms/Utils/BuildLibCalls.h"
3234
#include "llvm/Transforms/Utils/Local.h"
3335

@@ -922,6 +924,251 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
922924
return true;
923925
}
924926

927+
// Hidden command-line knob bounding how many bytes of a constant string the
// strcmp/strncmp inliner is willing to compare inline. optimizeStrNCmp()
// refuses to transform when the byte count exceeds this value, and does
// nothing at all when the value is below 2.
static cl::opt<unsigned> StrNCmpInlineThreshold(
    "strncmp-inline-threshold", cl::init(3), cl::Hidden,
    cl::desc("The maximum length of a constant string for a "
             "builtin string cmp call eligible for inlining. "
             "The default value is 3."));
931+
932+
namespace {
933+
class StrNCmpInliner {
934+
public:
935+
StrNCmpInliner(CallInst *CI, LibFunc Func, Function::iterator &BBNext,
936+
DomTreeUpdater *DTU, const DataLayout &DL)
937+
: CI(CI), Func(Func), BBNext(BBNext), DTU(DTU), DL(DL) {}
938+
939+
bool optimizeStrNCmp();
940+
941+
private:
942+
bool inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Switched);
943+
944+
CallInst *CI;
945+
LibFunc Func;
946+
Function::iterator &BBNext;
947+
DomTreeUpdater *DTU;
948+
const DataLayout &DL;
949+
};
950+
951+
} // namespace
952+
953+
/// First we normalize calls to strncmp/strcmp to the form of
954+
/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
955+
/// (without considering '\0')
956+
///
957+
/// Examples:
958+
///
959+
/// \code
960+
/// strncmp(s, "a", 3) -> compare(s, "a", 2)
961+
/// strncmp(s, "abc", 3) -> compare(s, "abc", 3)
962+
/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
963+
/// strcmp(s, "a") -> compare(s, "a", 2)
964+
///
965+
/// char s2[] = {'a'}
966+
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
967+
///
968+
/// char s2[] = {'a', 'b', 'c', 'd'}
969+
/// strncmp(s, s2, 3) -> compare(s, s2, 3)
970+
/// \endcode
971+
///
972+
/// We only handle cases where N and exactly one of s1 and s2 are constant. Cases
973+
/// where s1 and s2 are both constant are already handled by the instcombine
974+
/// pass.
975+
///
976+
/// We do not handle cases that N > StrNCmpInlineThreshold.
977+
///
978+
/// We also do not handle cases where N < 2, which are already
979+
/// handled by the instcombine pass.
980+
///
981+
bool StrNCmpInliner::optimizeStrNCmp() {
982+
if (StrNCmpInlineThreshold < 2)
983+
return false;
984+
985+
Value *Str1P = CI->getArgOperand(0);
986+
Value *Str2P = CI->getArgOperand(1);
987+
// should be handled elsewhere
988+
if (Str1P == Str2P)
989+
return false;
990+
991+
StringRef Str1, Str2;
992+
bool HasStr1 = getConstantStringInfo(Str1P, Str1, false);
993+
bool HasStr2 = getConstantStringInfo(Str2P, Str2, false);
994+
if (!(HasStr1 ^ HasStr2))
995+
return false;
996+
997+
// note that '\0' and characters after it are not trimmed
998+
StringRef Str = HasStr1 ? Str1 : Str2;
999+
1000+
size_t Idx = Str.find('\0');
1001+
uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1002+
if (Func == LibFunc_strncmp) {
1003+
if (!isa<ConstantInt>(CI->getArgOperand(2)))
1004+
return false;
1005+
N = std::min(N, cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue());
1006+
}
1007+
// now N means how many bytes we need to compare at most
1008+
if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1009+
return false;
1010+
1011+
Value *StrP = HasStr1 ? Str2P : Str1P;
1012+
1013+
// cases that StrP has two or more dereferenceable bytes might be better
1014+
// optimized elsewhere
1015+
bool CanBeNull = false, CanBeFreed = false;
1016+
if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1017+
return false;
1018+
1019+
return inlineCompare(StrP, Str, N, HasStr1);
1020+
}
1021+
1022+
/// Convert
1023+
///
1024+
/// \code
1025+
/// ret = compare(s1, s2, N)
1026+
/// \endcode
1027+
///
1028+
/// into
1029+
///
1030+
/// \code
1031+
/// ret = (int)s1[0] - (int)s2[0]
1032+
/// if (ret != 0)
1033+
/// goto NE
1034+
/// ...
1035+
/// ret = (int)s1[N-2] - (int)s2[N-2]
1036+
/// if (ret != 0)
1037+
/// goto NE
1038+
/// ret = (int)s1[N-1] - (int)s2[N-1]
1039+
/// NE:
1040+
/// \endcode
1041+
///
1042+
/// CFG before and after the transformation:
1043+
///
1044+
/// (before)
1045+
/// BBCI
1046+
///
1047+
/// (after)
1048+
/// BBBefore -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBCI
1049+
/// | ^
1050+
/// E |
1051+
/// | |
1052+
/// BBSubs[1] (sub,icmp) --NE-----+
1053+
/// ... |
1054+
/// BBSubs[N-1] (sub) ---------+
1055+
///
1056+
bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
                                   bool Switched) {
  // Replaces CI with a chain of N single-byte comparisons between the
  // non-constant string LHS and the constant RHS. Switched is true when the
  // constant was the call's first argument, in which case each difference is
  // computed as (constant - loaded) instead of (loaded - constant).
  // Always returns true.
  IRBuilder<> B(CI->getContext());
  // The generated loads stand in for the library call and can fault just like
  // it could, so attribute all new instructions to the call's source location.
  B.SetCurrentDebugLocation(CI->getDebugLoc());

  BasicBlock *BBCI = CI->getParent();
  bool IsEntry = BBCI->isEntryBlock();
  BasicBlock *BBBefore = splitBlockBefore(BBCI, CI, DTU, nullptr, nullptr,
                                          BBCI->getName() + ".before");

  // N comparison blocks plus one trailing join block (BBNE), all placed
  // immediately before BBCI.
  SmallVector<BasicBlock *> BBSubs;
  for (uint64_t i = 0; i < N + 1; ++i)
    BBSubs.push_back(
        BasicBlock::Create(CI->getContext(), "sub", BBCI->getParent(), BBCI));
  BasicBlock *BBNE = BBSubs[N];

  // Redirect the fall-through created by the split into the first comparison.
  cast<BranchInst>(BBBefore->getTerminator())->setSuccessor(0, BBSubs[0]);

  // The join block selects whichever difference ended the chain.
  B.SetInsertPoint(BBNE);
  PHINode *Phi = B.CreatePHI(CI->getType(), N);
  B.CreateBr(BBCI);

  Value *Base = LHS;
  for (uint64_t i = 0; i < N; ++i) {
    B.SetInsertPoint(BBSubs[i]);
    Value *VL = B.CreateZExt(
        B.CreateLoad(B.getInt8Ty(),
                     B.CreateInBoundsGEP(B.getInt8Ty(), Base, B.getInt64(i))),
        CI->getType());
    // strcmp/strncmp compare bytes as unsigned char. The loaded byte is
    // zero-extended above, so the constant byte must be widened the same way;
    // passing a plain (possibly signed) char would sign-extend bytes >= 0x80
    // and yield a difference with the wrong sign.
    Value *VR =
        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
    Value *Sub = Switched ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
    if (i < N - 1)
      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
                     BBNE, BBSubs[i + 1]);
    else
      B.CreateBr(BBNE);

    Phi->addIncoming(Sub, BBSubs[i]);
  }

  CI->replaceAllUsesWith(Phi);
  CI->eraseFromParent();

  // Let the caller resume scanning at the block holding the instructions that
  // followed the call.
  BBNext = BBCI->getIterator();

  // Update DomTree
  if (DTU) {
    if (IsEntry) {
      // The call was in the entry block, so the split produced a new entry
      // block; rebuild the tree rather than patching it incrementally.
      DTU->recalculate(*BBCI->getParent());
    } else {
      SmallVector<DominatorTree::UpdateType, 8> Updates;
      Updates.push_back({DominatorTree::Delete, BBBefore, BBCI});
      Updates.push_back({DominatorTree::Insert, BBBefore, BBSubs[0]});
      for (uint64_t i = 0; i < N; ++i) {
        if (i < N - 1)
          Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
      }
      Updates.push_back({DominatorTree::Insert, BBNE, BBCI});
      DTU->applyUpdates(Updates);
    }
  }
  return true;
}
1119+
1120+
/// Scans every basic block of \p F for calls to strcmp/strncmp that are
/// recognized, available library functions and tries to expand them inline
/// via StrNCmpInliner.
///
/// \param MadeCFGChange reset to false on entry and set to true once any call
///        has been expanded (the expansion adds new basic blocks).
/// \returns true if at least one call was expanded.
static bool inlineLibCalls(Function &F, TargetLibraryInfo &TLI,
                           const TargetTransformInfo &TTI, DominatorTree &DT,
                           bool &MadeCFGChange) {
  MadeCFGChange = false;
  DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
  auto &DL = F.getParent()->getDataLayout();

  bool Changed = false;

  Function::iterator NextBB = F.begin();
  const Function::iterator EndBB = F.end();
  while (NextBB != EndBB) {
    // Advance first: a successful expansion overwrites NextBB with the block
    // at which scanning has to resume.
    BasicBlock &Block = *NextBB++;

    for (Instruction &Inst : Block) {
      auto *Call = dyn_cast<CallInst>(&Inst);
      if (!Call)
        continue;

      Function *Callee = Call->getCalledFunction();
      if (!Callee)
        continue;

      if (Call->isNoBuiltin())
        continue;

      // Skip if function either has local linkage or is not a known library
      // function.
      LibFunc LF;
      if (Callee->hasLocalLinkage() || !TLI.getLibFunc(*Callee, LF) ||
          !TLI.has(LF))
        continue;

      // Only the string comparison routines are handled here.
      if (LF != LibFunc_strcmp && LF != LibFunc_strncmp)
        continue;

      if (!StrNCmpInliner(Call, LF, NextBB, &DTU, DL).optimizeStrNCmp())
        continue;

      // The call has been erased and its block split, so stop walking this
      // block right away and continue from NextBB.
      MadeCFGChange = true;
      Changed = true;
      break;
    }
  }

  return Changed;
}
1171+
9251172
/// This is the entry point for folds that could be implemented in regular
9261173
/// InstCombine, but they are separated because they are not expected to
9271174
/// occur frequently and/or have more than a constant-length pattern match.
@@ -969,11 +1216,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
9691216
/// handled in the callers of this function.
9701217
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
9711218
TargetLibraryInfo &TLI, DominatorTree &DT,
972-
AliasAnalysis &AA) {
1219+
AliasAnalysis &AA, bool &MadeCFGChange) {
9731220
bool MadeChange = false;
9741221
const DataLayout &DL = F.getParent()->getDataLayout();
9751222
TruncInstCombine TIC(AC, TLI, DL, DT);
9761223
MadeChange |= TIC.run(F);
1224+
MadeChange |= inlineLibCalls(F, TLI, TTI, DT, MadeCFGChange);
9771225
MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC);
9781226
return MadeChange;
9791227
}
@@ -985,12 +1233,16 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
9851233
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
9861234
auto &TTI = AM.getResult<TargetIRAnalysis>(F);
9871235
auto &AA = AM.getResult<AAManager>(F);
988-
if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
1236+
bool MadeCFGChange = false;
1237+
if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
9891238
// No changes, all analyses are preserved.
9901239
return PreservedAnalyses::all();
9911240
}
9921241
// Mark all the analyses that instcombine updates as preserved.
9931242
PreservedAnalyses PA;
994-
PA.preserveSet<CFGAnalyses>();
1243+
if (MadeCFGChange)
1244+
PA.preserve<DominatorTreeAnalysis>();
1245+
else
1246+
PA.preserveSet<CFGAnalyses>();
9951247
return PA;
9961248
}

0 commit comments

Comments
 (0)