Skip to content

Commit 3d42d54

Browse files
committed
[ConstraintElimination] Add constraint elimination pass.
This patch is a first draft of a new pass that adds a more flexible way to eliminate compares based on more complex constraints collected from dominating conditions. In particular, it aims at simplifying conditions of the forms below using a forward propagation approach, rather than instcomine-style ad-hoc backwards walking of def-use chains. if (x < y) if (y < z) if (x < z) <- simplify or if (x + 2 < y) if (x + 1 < y) <- simplify assuming no wraps The general approach is to collect conditions and blocks, sort them by dominance and then iterate over the sorted list. Conditions are turned into a linear inequality and add it to a system containing the linear inequalities that hold on entry to the block. For blocks, we check each compare against the system and see if it is implied by the constraints in the system. We also keep a stack of processed conditions and remove conditions from the stack and the constraint system once they go out-of-scope (= do not dominate the current block any longer). Currently there still are the least the following areas for improvements * Currently large unsigned constants cannot be added to the system (coefficients must be represented as integers) * The way constraints are managed currently is not very optimized. Reviewed By: spatel Differential Revision: https://reviews.llvm.org/D84547
1 parent ca76d6e commit 3d42d54

File tree

16 files changed

+398
-60
lines changed

16 files changed

+398
-60
lines changed

llvm/include/llvm/Analysis/ConstraintSystem.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ class ConstraintSystem {
4949
Constraints.push_back(R);
5050
}
5151

52+
void addVariableRowFill(const SmallVector<int64_t, 8> &R) {
53+
for (auto &CR : Constraints) {
54+
while (CR.size() != R.size())
55+
CR.push_back(0);
56+
}
57+
addVariableRow(R);
58+
}
59+
5260
/// Returns true if there may be a solution for the constraints in the system.
5361
bool mayHaveSolution();
5462

@@ -62,6 +70,8 @@ class ConstraintSystem {
6270
}
6371

6472
bool isConditionImplied(SmallVector<int64_t, 8> R);
73+
74+
void popLastConstraint() { Constraints.pop_back(); }
6575
};
6676
} // namespace llvm
6777

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ void initializeCalledValuePropagationLegacyPassPass(PassRegistry &);
113113
void initializeCodeGenPreparePass(PassRegistry&);
114114
void initializeConstantHoistingLegacyPassPass(PassRegistry&);
115115
void initializeConstantMergeLegacyPassPass(PassRegistry&);
116+
void initializeConstraintEliminationPass(PassRegistry &);
116117
void initializeControlHeightReductionLegacyPassPass(PassRegistry&);
117118
void initializeCorrelatedValuePropagationPass(PassRegistry&);
118119
void initializeCostModelAnalysisPass(PassRegistry&);

llvm/include/llvm/Transforms/Scalar.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,13 @@ Pass *createLoopDeletionPass();
340340
//
341341
FunctionPass *createConstantHoistingPass();
342342

343+
//===----------------------------------------------------------------------===//
344+
//
345+
// ConstraintElimination - This pass eliminates conditions based on found
346+
// constraints.
347+
//
348+
FunctionPass *createConstraintEliminationPass();
349+
343350
//===----------------------------------------------------------------------===//
344351
//
345352
// Sink - Code Sinking

llvm/lib/Transforms/IPO/PassManagerBuilder.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ cl::opt<bool> EnableMatrix(
153153
"enable-matrix", cl::init(false), cl::Hidden,
154154
cl::desc("Enable lowering of the matrix intrinsics"));
155155

156+
cl::opt<bool> EnableConstraintElimination(
157+
"enable-constraint-elimination", cl::init(false), cl::Hidden,
158+
cl::desc(
159+
"Enable pass to eliminate conditions based on linear constraints."));
160+
156161
cl::opt<AttributorRunOption> AttributorRun(
157162
"attributor-enable", cl::Hidden, cl::init(AttributorRunOption::NONE),
158163
cl::desc("Enable the attributor inter-procedural deduction pass."),
@@ -381,6 +386,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses(
381386
}
382387
}
383388

389+
if (EnableConstraintElimination)
390+
MPM.add(createConstraintEliminationPass());
391+
384392
if (OptLevel > 1) {
385393
// Speculative execution if the target has divergent branches; otherwise nop.
386394
MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass());

llvm/lib/Transforms/Scalar/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_llvm_component_library(LLVMScalarOpts
44
BDCE.cpp
55
CallSiteSplitting.cpp
66
ConstantHoisting.cpp
7+
ConstraintElimination.cpp
78
CorrelatedValuePropagation.cpp
89
DCE.cpp
910
DeadStoreElimination.cpp
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
//===-- ConstraintElimination.cpp - Eliminate conds using constraints. ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Eliminate conditions based on constraints collected from dominating
10+
// conditions.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "llvm/ADT/SmallVector.h"
15+
#include "llvm/ADT/Statistic.h"
16+
#include "llvm/Analysis/ConstraintSystem.h"
17+
#include "llvm/Analysis/GlobalsModRef.h"
18+
#include "llvm/IR/DataLayout.h"
19+
#include "llvm/IR/Dominators.h"
20+
#include "llvm/IR/Function.h"
21+
#include "llvm/IR/Instructions.h"
22+
#include "llvm/IR/PatternMatch.h"
23+
#include "llvm/InitializePasses.h"
24+
#include "llvm/Pass.h"
25+
#include "llvm/Support/Debug.h"
26+
#include "llvm/Support/DebugCounter.h"
27+
#include "llvm/Transforms/Scalar.h"
28+
29+
using namespace llvm;
30+
using namespace PatternMatch;
31+
32+
#define DEBUG_TYPE "constraint-elimination"
33+
34+
STATISTIC(NumCondsRemoved, "Number of instructions removed");
35+
DEBUG_COUNTER(EliminatedCounter, "conds-eliminated",
36+
"Controls which conditions are eliminated");
37+
38+
static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max();
39+
40+
Optional<std::pair<int64_t, Value *>> decompose(Value *V) {
41+
if (auto *CI = dyn_cast<ConstantInt>(V)) {
42+
if (CI->isNegative() || CI->uge(MaxConstraintValue))
43+
return {};
44+
return {{CI->getSExtValue(), nullptr}};
45+
}
46+
auto *GEP = dyn_cast<GetElementPtrInst>(V);
47+
if (GEP && GEP->getNumOperands() == 2 &&
48+
isa<ConstantInt>(GEP->getOperand(GEP->getNumOperands() - 1))) {
49+
return {{cast<ConstantInt>(GEP->getOperand(GEP->getNumOperands() - 1))
50+
->getSExtValue(),
51+
GEP->getPointerOperand()}};
52+
}
53+
return {{0, V}};
54+
}
55+
56+
/// Turn a condition \p CmpI into a constraint vector, using indices from \p
57+
/// Value2Index. If \p ShouldAdd is true, new indices are added for values not
58+
/// yet in \p Value2Index.
59+
static SmallVector<int64_t, 8>
60+
getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1,
61+
DenseMap<Value *, unsigned> &Value2Index, bool ShouldAdd) {
62+
Value *A, *B;
63+
64+
int64_t Offset1 = 0;
65+
int64_t Offset2 = 0;
66+
67+
auto TryToGetIndex = [ShouldAdd,
68+
&Value2Index](Value *V) -> Optional<unsigned> {
69+
if (ShouldAdd) {
70+
Value2Index.insert({V, Value2Index.size() + 1});
71+
return Value2Index[V];
72+
}
73+
auto I = Value2Index.find(V);
74+
if (I == Value2Index.end())
75+
return None;
76+
return I->second;
77+
};
78+
79+
if (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_UGE)
80+
return getConstraint(CmpInst::getSwappedPredicate(Pred), Op1, Op0,
81+
Value2Index, ShouldAdd);
82+
83+
if (Pred == CmpInst::ICMP_ULE || Pred == CmpInst::ICMP_ULT) {
84+
auto ADec = decompose(Op0);
85+
auto BDec = decompose(Op1);
86+
if (!ADec || !BDec)
87+
return {};
88+
std::tie(Offset1, A) = *ADec;
89+
std::tie(Offset2, B) = *BDec;
90+
Offset1 *= -1;
91+
92+
if (!A && !B)
93+
return {};
94+
95+
auto AIdx = A ? TryToGetIndex(A) : None;
96+
auto BIdx = B ? TryToGetIndex(B) : None;
97+
if ((A && !AIdx) || (B && !BIdx))
98+
return {};
99+
100+
SmallVector<int64_t, 8> R(Value2Index.size() + 1, 0);
101+
if (AIdx)
102+
R[*AIdx] = 1;
103+
if (BIdx)
104+
R[*BIdx] = -1;
105+
R[0] = Offset1 + Offset2 + (Pred == CmpInst::ICMP_ULT ? -1 : 0);
106+
return R;
107+
}
108+
109+
return {};
110+
}
111+
112+
static SmallVector<int64_t, 8>
113+
getConstraint(CmpInst *Cmp, DenseMap<Value *, unsigned> &Value2Index,
114+
bool ShouldAdd) {
115+
return getConstraint(Cmp->getPredicate(), Cmp->getOperand(0),
116+
Cmp->getOperand(1), Value2Index, ShouldAdd);
117+
}
118+
119+
/// Represents either a condition that holds on entry to a block or a basic
120+
/// block, with their respective Dominator DFS in and out numbers.
121+
struct ConstraintOrBlock {
122+
unsigned NumIn;
123+
unsigned NumOut;
124+
bool IsBlock;
125+
bool Not;
126+
union {
127+
BasicBlock *BB;
128+
CmpInst *Condition;
129+
};
130+
131+
ConstraintOrBlock(DomTreeNode *DTN)
132+
: NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(true),
133+
BB(DTN->getBlock()) {}
134+
ConstraintOrBlock(DomTreeNode *DTN, CmpInst *Condition, bool Not)
135+
: NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(false),
136+
Not(Not), Condition(Condition) {}
137+
};
138+
139+
struct StackEntry {
140+
unsigned NumIn;
141+
unsigned NumOut;
142+
CmpInst *Condition;
143+
bool IsNot;
144+
145+
StackEntry(unsigned NumIn, unsigned NumOut, CmpInst *Condition, bool IsNot)
146+
: NumIn(NumIn), NumOut(NumOut), Condition(Condition), IsNot(IsNot) {}
147+
};
148+
149+
static bool eliminateConstraints(Function &F, DominatorTree &DT) {
150+
bool Changed = false;
151+
DT.updateDFSNumbers();
152+
ConstraintSystem CS;
153+
154+
SmallVector<ConstraintOrBlock, 64> WorkList;
155+
156+
// First, collect conditions implied by branches and blocks with their
157+
// Dominator DFS in and out numbers.
158+
for (BasicBlock &BB : F) {
159+
if (!DT.getNode(&BB))
160+
continue;
161+
WorkList.emplace_back(DT.getNode(&BB));
162+
163+
auto *Br = dyn_cast<BranchInst>(BB.getTerminator());
164+
if (!Br || !Br->isConditional())
165+
continue;
166+
auto *CmpI = dyn_cast<CmpInst>(Br->getCondition());
167+
if (!CmpI)
168+
continue;
169+
if (Br->getSuccessor(0)->getSinglePredecessor())
170+
WorkList.emplace_back(DT.getNode(Br->getSuccessor(0)), CmpI, false);
171+
if (Br->getSuccessor(1)->getSinglePredecessor())
172+
WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true);
173+
}
174+
175+
// Next, sort worklist by dominance, so that dominating blocks and conditions
176+
// come before blocks and conditions dominated by them. If a block and a
177+
// condition have the same numbers, the condition comes before the block, as
178+
// it holds on entry to the block.
179+
sort(WorkList.begin(), WorkList.end(),
180+
[](const ConstraintOrBlock &A, const ConstraintOrBlock &B) {
181+
return std::tie(A.NumIn, A.IsBlock) < std::tie(B.NumIn, B.IsBlock);
182+
});
183+
184+
// Finally, process ordered worklist and eliminate implied conditions.
185+
SmallVector<StackEntry, 16> DFSInStack;
186+
DenseMap<Value *, unsigned> Value2Index;
187+
for (ConstraintOrBlock &CB : WorkList) {
188+
// First, pop entries from the stack that are out-of-scope for CB. Remove
189+
// the corresponding entry from the constraint system.
190+
while (!DFSInStack.empty()) {
191+
auto &E = DFSInStack.back();
192+
LLVM_DEBUG(dbgs() << "Top of stack : " << E.NumIn << " " << E.NumOut
193+
<< "\n");
194+
LLVM_DEBUG(dbgs() << "CB: " << CB.NumIn << " " << CB.NumOut << "\n");
195+
bool IsDom = CB.NumIn >= E.NumIn && CB.NumOut <= E.NumOut;
196+
if (IsDom)
197+
break;
198+
LLVM_DEBUG(dbgs() << "Removing " << *E.Condition << " " << E.IsNot
199+
<< "\n");
200+
DFSInStack.pop_back();
201+
CS.popLastConstraint();
202+
}
203+
204+
LLVM_DEBUG({
205+
dbgs() << "Processing ";
206+
if (CB.IsBlock)
207+
dbgs() << *CB.BB;
208+
else
209+
dbgs() << *CB.Condition;
210+
dbgs() << "\n";
211+
});
212+
213+
// For a block, check if any CmpInsts become known based on the current set
214+
// of constraints.
215+
if (CB.IsBlock) {
216+
for (Instruction &I : *CB.BB) {
217+
auto *Cmp = dyn_cast<CmpInst>(&I);
218+
if (!Cmp)
219+
continue;
220+
auto R = getConstraint(Cmp, Value2Index, false);
221+
if (R.empty())
222+
continue;
223+
if (CS.isConditionImplied(R)) {
224+
if (!DebugCounter::shouldExecute(EliminatedCounter))
225+
continue;
226+
227+
LLVM_DEBUG(dbgs() << "Condition " << *Cmp
228+
<< " implied by dominating constraints\n");
229+
LLVM_DEBUG({
230+
for (auto &E : reverse(DFSInStack))
231+
dbgs() << " C " << *E.Condition << " " << E.IsNot << "\n";
232+
});
233+
Cmp->replaceAllUsesWith(
234+
ConstantInt::getTrue(F.getParent()->getContext()));
235+
NumCondsRemoved++;
236+
Changed = true;
237+
}
238+
if (CS.isConditionImplied(ConstraintSystem::negate(R))) {
239+
if (!DebugCounter::shouldExecute(EliminatedCounter))
240+
continue;
241+
242+
LLVM_DEBUG(dbgs() << "Condition !" << *Cmp
243+
<< " implied by dominating constraints\n");
244+
LLVM_DEBUG({
245+
for (auto &E : reverse(DFSInStack))
246+
dbgs() << " C " << *E.Condition << " " << E.IsNot << "\n";
247+
});
248+
Cmp->replaceAllUsesWith(
249+
ConstantInt::getFalse(F.getParent()->getContext()));
250+
NumCondsRemoved++;
251+
Changed = true;
252+
}
253+
}
254+
continue;
255+
}
256+
257+
// Otherwise, add the condition to the system and stack, if we can transform
258+
// it into a constraint.
259+
auto R = getConstraint(CB.Condition, Value2Index, true);
260+
if (R.empty())
261+
continue;
262+
263+
LLVM_DEBUG(dbgs() << "Adding " << *CB.Condition << " " << CB.Not << "\n");
264+
if (CB.Not)
265+
R = ConstraintSystem::negate(R);
266+
267+
CS.addVariableRowFill(R);
268+
DFSInStack.emplace_back(CB.NumIn, CB.NumOut, CB.Condition, CB.Not);
269+
}
270+
271+
return Changed;
272+
}
273+
274+
namespace {
275+
276+
class ConstraintElimination : public FunctionPass {
277+
public:
278+
static char ID;
279+
280+
ConstraintElimination() : FunctionPass(ID) {
281+
initializeConstraintEliminationPass(*PassRegistry::getPassRegistry());
282+
}
283+
284+
bool runOnFunction(Function &F) override {
285+
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
286+
return eliminateConstraints(F, DT);
287+
}
288+
289+
void getAnalysisUsage(AnalysisUsage &AU) const override {
290+
AU.setPreservesCFG();
291+
AU.addRequired<DominatorTreeWrapperPass>();
292+
AU.addPreserved<GlobalsAAWrapperPass>();
293+
AU.addPreserved<DominatorTreeWrapperPass>();
294+
}
295+
};
296+
297+
} // end anonymous namespace
298+
299+
char ConstraintElimination::ID = 0;
300+
301+
INITIALIZE_PASS_BEGIN(ConstraintElimination, "constraint-elimination",
302+
"Constraint Elimination", false, false)
303+
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
304+
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
305+
INITIALIZE_PASS_END(ConstraintElimination, "constraint-elimination",
306+
"Constraint Elimination", false, false)
307+
308+
FunctionPass *llvm::createConstraintEliminationPass() {
309+
return new ConstraintElimination();
310+
}

llvm/lib/Transforms/Scalar/Scalar.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
3838
initializeAlignmentFromAssumptionsPass(Registry);
3939
initializeCallSiteSplittingLegacyPassPass(Registry);
4040
initializeConstantHoistingLegacyPassPass(Registry);
41+
initializeConstraintEliminationPass(Registry);
4142
initializeCorrelatedValuePropagationPass(Registry);
4243
initializeDCELegacyPassPass(Registry);
4344
initializeDeadInstEliminationPass(Registry);

0 commit comments

Comments
 (0)