Skip to content

Commit 652a01a

Browse files
committed
[ValueTracking] Filter out non-interesting conditions
1 parent e776484 commit 652a01a

File tree

7 files changed

+95
-51
lines changed

7 files changed

+95
-51
lines changed

llvm/include/llvm/Analysis/DomConditionCache.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,34 @@
1818
#define LLVM_ANALYSIS_DOMCONDITIONCACHE_H
1919

2020
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/ADT/BitmaskEnum.h"
2122
#include "llvm/ADT/DenseMap.h"
2223
#include "llvm/ADT/SmallVector.h"
24+
#include <cstdint>
2325

2426
namespace llvm {
2527

2628
class Value;
2729
class BranchInst;
2830

31+
enum class DomConditionFlag : uint8_t {
32+
None = 0,
33+
KnownBits = 1 << 0,
34+
KnownFPClass = 1 << 1,
35+
PowerOfTwo = 1 << 2,
36+
ICmp = 1 << 3,
37+
};
38+
39+
LLVM_DECLARE_ENUM_AS_BITMASK(
40+
DomConditionFlag,
41+
/*LargestValue=*/static_cast<uint8_t>(DomConditionFlag::ICmp));
42+
2943
class DomConditionCache {
3044
private:
3145
/// A map of values about which a branch might be providing information.
32-
using AffectedValuesMap = DenseMap<Value *, SmallVector<BranchInst *, 1>>;
46+
using AffectedValuesMap =
47+
DenseMap<Value *,
48+
SmallVector<std::pair<BranchInst *, DomConditionFlag>, 1>>;
3349
AffectedValuesMap AffectedValues;
3450

3551
public:
@@ -40,10 +56,11 @@ class DomConditionCache {
4056
void removeValue(Value *V) { AffectedValues.erase(V); }
4157

4258
/// Access the list of branches which affect this value.
43-
ArrayRef<BranchInst *> conditionsFor(const Value *V) const {
59+
ArrayRef<std::pair<BranchInst *, DomConditionFlag>>
60+
conditionsFor(const Value *V) const {
4461
auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
4562
if (AVI == AffectedValues.end())
46-
return ArrayRef<BranchInst *>();
63+
return {};
4764

4865
return AVI->second;
4966
}

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
#ifndef LLVM_ANALYSIS_VALUETRACKING_H
1515
#define LLVM_ANALYSIS_VALUETRACKING_H
1616

17+
#include "DomConditionCache.h"
1718
#include "llvm/Analysis/SimplifyQuery.h"
1819
#include "llvm/Analysis/WithCache.h"
1920
#include "llvm/IR/Constants.h"
2021
#include "llvm/IR/DataLayout.h"
2122
#include "llvm/IR/FMF.h"
22-
#include "llvm/IR/Instructions.h"
2323
#include "llvm/IR/InstrTypes.h"
24+
#include "llvm/IR/Instructions.h"
2425
#include "llvm/IR/Intrinsics.h"
2526
#include <cassert>
2627
#include <cstdint>
@@ -1275,8 +1276,9 @@ std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
12751276
/// Call \p InsertAffected on all Values whose known bits / value may be
12761277
/// affected by the condition \p Cond. Used by AssumptionCache and
12771278
/// DomConditionCache.
1278-
void findValuesAffectedByCondition(Value *Cond, bool IsAssume,
1279-
function_ref<void(Value *)> InsertAffected);
1279+
void findValuesAffectedByCondition(
1280+
Value *Cond, bool IsAssume,
1281+
function_ref<void(Value *, DomConditionFlag)> InsertAffected);
12801282

12811283
} // end namespace llvm
12821284

llvm/lib/Analysis/AssumptionCache.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
5959
// Note: This code must be kept in-sync with the code in
6060
// computeKnownBitsFromAssume in ValueTracking.
6161

62-
auto InsertAffected = [&Affected](Value *V) {
62+
// TODO: Use DomConditionFlag to filter out non-interesting conditions.
63+
auto InsertAffected = [&Affected](Value *V, DomConditionFlag) {
6364
Affected.push_back({V, AssumptionCache::ExprResultIdx});
6465
};
6566

llvm/lib/Analysis/DomConditionCache.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,30 @@
1010
#include "llvm/Analysis/ValueTracking.h"
1111
using namespace llvm;
1212

13-
static void findAffectedValues(Value *Cond,
14-
SmallVectorImpl<Value *> &Affected) {
15-
auto InsertAffected = [&Affected](Value *V) { Affected.push_back(V); };
13+
static void findAffectedValues(
14+
Value *Cond,
15+
SmallVectorImpl<std::pair<Value *, DomConditionFlag>> &Affected) {
16+
auto InsertAffected = [&Affected](Value *V, DomConditionFlag Flags) {
17+
Affected.push_back({V, Flags});
18+
};
1619
findValuesAffectedByCondition(Cond, /*IsAssume=*/false, InsertAffected);
1720
}
1821

1922
void DomConditionCache::registerBranch(BranchInst *BI) {
2023
assert(BI->isConditional() && "Must be conditional branch");
21-
SmallVector<Value *, 16> Affected;
24+
SmallVector<std::pair<Value *, DomConditionFlag>, 16> Affected;
2225
findAffectedValues(BI->getCondition(), Affected);
23-
for (Value *V : Affected) {
26+
for (auto [V, Flags] : Affected) {
2427
auto &AV = AffectedValues[V];
25-
if (!is_contained(AV, BI))
26-
AV.push_back(BI);
28+
bool Exist = false;
29+
for (auto &[OtherBI, OtherFlags] : AV) {
30+
if (OtherBI == BI) {
31+
OtherFlags |= Flags;
32+
Exist = true;
33+
break;
34+
}
35+
}
36+
if (!Exist)
37+
AV.push_back({BI, Flags});
2738
}
2839
}

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,9 @@ void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
790790

791791
if (Q.DC && Q.DT) {
792792
// Handle dominating conditions.
793-
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
793+
for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
794+
if (!any(Flag & DomConditionFlag::KnownBits))
795+
continue;
794796
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
795797
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
796798
computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
@@ -2299,7 +2301,9 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
22992301

23002302
// Handle dominating conditions.
23012303
if (Q.DC && Q.CxtI && Q.DT) {
2302-
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
2304+
for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
2305+
if (!any(Flag & DomConditionFlag::PowerOfTwo))
2306+
continue;
23032307
Value *Cond = BI->getCondition();
23042308

23052309
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -4930,7 +4934,9 @@ static KnownFPClass computeKnownFPClassFromContext(const Value *V,
49304934

49314935
if (Q.DC && Q.DT) {
49324936
// Handle dominating conditions.
4933-
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4937+
for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
4938+
if (!any(Flag & DomConditionFlag::KnownFPClass))
4939+
continue;
49344940
Value *Cond = BI->getCondition();
49354941

49364942
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -10014,36 +10020,38 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
1001410020
return CR;
1001510021
}
1001610022

10017-
static void
10018-
addValueAffectedByCondition(Value *V,
10019-
function_ref<void(Value *)> InsertAffected) {
10023+
static void addValueAffectedByCondition(
10024+
Value *V, function_ref<void(Value *, DomConditionFlag)> InsertAffected,
10025+
DomConditionFlag Flags) {
1002010026
assert(V != nullptr);
1002110027
if (isa<Argument>(V) || isa<GlobalValue>(V)) {
10022-
InsertAffected(V);
10028+
InsertAffected(V, Flags);
1002310029
} else if (auto *I = dyn_cast<Instruction>(V)) {
10024-
InsertAffected(V);
10030+
InsertAffected(V, Flags);
1002510031

1002610032
// Peek through unary operators to find the source of the condition.
1002710033
Value *Op;
1002810034
if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
1002910035
if (isa<Instruction>(Op) || isa<Argument>(Op))
10030-
InsertAffected(Op);
10036+
InsertAffected(Op, Flags);
1003110037
}
1003210038
}
1003310039
}
1003410040

1003510041
void llvm::findValuesAffectedByCondition(
10036-
Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
10037-
auto AddAffected = [&InsertAffected](Value *V) {
10038-
addValueAffectedByCondition(V, InsertAffected);
10042+
Value *Cond, bool IsAssume,
10043+
function_ref<void(Value *, DomConditionFlag)> InsertAffected) {
10044+
auto AddAffected = [&InsertAffected](Value *V, DomConditionFlag Flags) {
10045+
addValueAffectedByCondition(V, InsertAffected, Flags);
1003910046
};
1004010047

10041-
auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
10048+
auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS,
10049+
DomConditionFlag Flags) {
1004210050
if (IsAssume) {
10043-
AddAffected(LHS);
10044-
AddAffected(RHS);
10051+
AddAffected(LHS, Flags);
10052+
AddAffected(RHS, Flags);
1004510053
} else if (match(RHS, m_Constant()))
10046-
AddAffected(LHS);
10054+
AddAffected(LHS, Flags);
1004710055
};
1004810056

1004910057
SmallVector<Value *, 8> Worklist;
@@ -10058,9 +10066,9 @@ void llvm::findValuesAffectedByCondition(
1005810066
Value *A, *B, *X;
1005910067

1006010068
if (IsAssume) {
10061-
AddAffected(V);
10069+
AddAffected(V, DomConditionFlag::KnownBits);
1006210070
if (match(V, m_Not(m_Value(X))))
10063-
AddAffected(X);
10071+
AddAffected(X, DomConditionFlag::KnownBits);
1006410072
}
1006510073

1006610074
if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
@@ -10074,7 +10082,8 @@ void llvm::findValuesAffectedByCondition(
1007410082
Worklist.push_back(B);
1007510083
}
1007610084
} else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
10077-
AddCmpOperands(A, B);
10085+
AddCmpOperands(A, B,
10086+
DomConditionFlag::KnownBits | DomConditionFlag::ICmp);
1007810087

1007910088
bool HasRHSC = match(B, m_ConstantInt());
1008010089
if (ICmpInst::isEquality(Pred)) {
@@ -10084,19 +10093,19 @@ void llvm::findValuesAffectedByCondition(
1008410093
// (X << C) or (X >>_s C) or (X >>_u C).
1008510094
if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
1008610095
match(A, m_Shift(m_Value(X), m_ConstantInt())))
10087-
AddAffected(X);
10096+
AddAffected(X, DomConditionFlag::KnownBits);
1008810097
else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
1008910098
match(A, m_Or(m_Value(X), m_Value(Y)))) {
10090-
AddAffected(X);
10091-
AddAffected(Y);
10099+
AddAffected(X, DomConditionFlag::KnownBits);
10100+
AddAffected(Y, DomConditionFlag::KnownBits);
1009210101
}
1009310102
}
1009410103
} else {
1009510104
if (HasRHSC) {
1009610105
// Handle (A + C1) u< C2, which is the canonical form of
1009710106
// A > C3 && A < C4.
1009810107
if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
10099-
AddAffected(X);
10108+
AddAffected(X, DomConditionFlag::KnownBits);
1010010109

1010110110
if (ICmpInst::isUnsigned(Pred)) {
1010210111
Value *Y;
@@ -10106,42 +10115,42 @@ void llvm::findValuesAffectedByCondition(
1010610115
if (match(A, m_And(m_Value(X), m_Value(Y))) ||
1010710116
match(A, m_Or(m_Value(X), m_Value(Y))) ||
1010810117
match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
10109-
AddAffected(X);
10110-
AddAffected(Y);
10118+
AddAffected(X, DomConditionFlag::KnownBits);
10119+
AddAffected(Y, DomConditionFlag::KnownBits);
1011110120
}
1011210121
// X nuw- Y u> C -> X u> C
1011310122
if (match(A, m_NUWSub(m_Value(X), m_Value())))
10114-
AddAffected(X);
10123+
AddAffected(X, DomConditionFlag::KnownBits);
1011510124
}
1011610125
}
1011710126

1011810127
// Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
1011910128
// by computeKnownFPClass().
1012010129
if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
1012110130
if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
10122-
InsertAffected(X);
10131+
InsertAffected(X, DomConditionFlag::KnownFPClass);
1012310132
else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
10124-
InsertAffected(X);
10133+
InsertAffected(X, DomConditionFlag::KnownFPClass);
1012510134
}
1012610135
}
1012710136

1012810137
if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
10129-
AddAffected(X);
10138+
AddAffected(X, DomConditionFlag::PowerOfTwo);
1013010139
} else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
10131-
AddCmpOperands(A, B);
10140+
AddCmpOperands(A, B, DomConditionFlag::KnownFPClass);
1013210141

1013310142
// fcmp fneg(x), y
1013410143
// fcmp fabs(x), y
1013510144
// fcmp fneg(fabs(x)), y
1013610145
if (match(A, m_FNeg(m_Value(A))))
10137-
AddAffected(A);
10146+
AddAffected(A, DomConditionFlag::KnownFPClass);
1013810147
if (match(A, m_FAbs(m_Value(A))))
10139-
AddAffected(A);
10148+
AddAffected(A, DomConditionFlag::KnownFPClass);
1014010149

1014110150
} else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
1014210151
m_Value()))) {
1014310152
// Handle patterns that computeKnownFPClass() support.
10144-
AddAffected(A);
10153+
AddAffected(A, DomConditionFlag::KnownFPClass);
1014510154
}
1014610155
}
1014710156
}

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1385,7 +1385,9 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
13851385
return nullptr;
13861386
};
13871387

1388-
for (BranchInst *BI : DC.conditionsFor(X)) {
1388+
for (auto [BI, Flags] : DC.conditionsFor(X)) {
1389+
if (!any(Flags & DomConditionFlag::ICmp))
1390+
continue;
13891391
ICmpInst::Predicate DomPred;
13901392
const APInt *DomC;
13911393
if (!match(BI->getCondition(),

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4293,9 +4293,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
42934293
(!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
42944294
// Try to simplify select arms based on KnownBits implied by the condition.
42954295
CondContext CC(CondVal);
4296-
findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
4297-
CC.AffectedValues.insert(V);
4298-
});
4296+
findValuesAffectedByCondition(
4297+
CondVal, /*IsAssume=*/false, [&](Value *V, DomConditionFlag Flags) {
4298+
if (any(Flags & DomConditionFlag::KnownBits))
4299+
CC.AffectedValues.insert(V);
4300+
});
42994301
SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
43004302
if (!CC.AffectedValues.empty()) {
43014303
if (!isa<Constant>(TrueVal) &&

0 commit comments

Comments
 (0)