Skip to content

[ValueTracking] Filter out non-interesting conditions #118493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions llvm/include/llvm/Analysis/DomConditionCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,53 @@
#define LLVM_ANALYSIS_DOMCONDITIONCACHE_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

namespace llvm {

class Value;
class BranchInst;

enum class DomConditionFlag : uint8_t {
None = 0,
KnownBits = 1 << 0,
KnownFPClass = 1 << 1,
PowerOfTwo = 1 << 2,
ICmp = 1 << 3,
LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/ICmp),
};

LLVM_DECLARE_ENUM_AS_BITMASK(
DomConditionFlag,
/*LargestValue=*/static_cast<uint8_t>(DomConditionFlag::ICmp));

class DomConditionCache {
private:
/// A map of values about which a branch might be providing information.
using AffectedValuesMap = DenseMap<Value *, SmallVector<BranchInst *, 1>>;
AffectedValuesMap AffectedValues;
AffectedValuesMap AffectedValues[BitWidth<DomConditionFlag>];

public:
/// Add a branch condition to the cache.
void registerBranch(BranchInst *BI);

/// Remove a value from the cache, e.g. because it will be erased.
void removeValue(Value *V) { AffectedValues.erase(V); }
void removeValue(Value *V) {
for (auto &Table : AffectedValues)
Table.erase(V);
}

/// Access the list of branches which affect this value.
ArrayRef<BranchInst *> conditionsFor(const Value *V) const {
auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
if (AVI == AffectedValues.end())
return ArrayRef<BranchInst *>();
ArrayRef<BranchInst *> conditionsFor(const Value *V,
DomConditionFlag Filter) const {
assert(has_single_bit(to_underlying(Filter)));
auto &Values = AffectedValues[countr_zero(to_underlying(Filter))];
auto AVI = Values.find_as(const_cast<Value *>(V));
if (AVI == Values.end())
return {};

return AVI->second;
}
Expand Down
8 changes: 5 additions & 3 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
#ifndef LLVM_ANALYSIS_VALUETRACKING_H
#define LLVM_ANALYSIS_VALUETRACKING_H

#include "DomConditionCache.h"
#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include <cassert>
#include <cstdint>
Expand Down Expand Up @@ -1274,8 +1275,9 @@ std::optional<bool> isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS,
/// Call \p InsertAffected on all Values whose known bits / value may be
/// affected by the condition \p Cond. Used by AssumptionCache and
/// DomConditionCache.
void findValuesAffectedByCondition(Value *Cond, bool IsAssume,
function_ref<void(Value *)> InsertAffected);
void findValuesAffectedByCondition(
Value *Cond, bool IsAssume,
function_ref<void(Value *, DomConditionFlag)> InsertAffected);

} // end namespace llvm

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Analysis/AssumptionCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
// Note: This code must be kept in-sync with the code in
// computeKnownBitsFromAssume in ValueTracking.

auto InsertAffected = [&Affected](Value *V) {
// TODO: Use DomConditionFlag to filter out non-interesting conditions.
auto InsertAffected = [&Affected](Value *V, DomConditionFlag) {
Affected.push_back({V, AssumptionCache::ExprResultIdx});
};

Expand Down
24 changes: 16 additions & 8 deletions llvm/lib/Analysis/DomConditionCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,27 @@
#include "llvm/Analysis/ValueTracking.h"
using namespace llvm;

static void findAffectedValues(Value *Cond,
SmallVectorImpl<Value *> &Affected) {
auto InsertAffected = [&Affected](Value *V) { Affected.push_back(V); };
static void findAffectedValues(
Value *Cond,
SmallVectorImpl<std::pair<Value *, DomConditionFlag>> &Affected) {
auto InsertAffected = [&Affected](Value *V, DomConditionFlag Flags) {
Affected.push_back({V, Flags});
};
findValuesAffectedByCondition(Cond, /*IsAssume=*/false, InsertAffected);
}

void DomConditionCache::registerBranch(BranchInst *BI) {
assert(BI->isConditional() && "Must be conditional branch");
SmallVector<Value *, 16> Affected;
SmallVector<std::pair<Value *, DomConditionFlag>, 16> Affected;
findAffectedValues(BI->getCondition(), Affected);
for (Value *V : Affected) {
auto &AV = AffectedValues[V];
if (!is_contained(AV, BI))
AV.push_back(BI);
for (auto [V, Flags] : Affected) {
uint32_t Underlying = to_underlying(Flags);
while (Underlying) {
uint32_t LSB = Underlying & -Underlying;
auto &AV = AffectedValues[countr_zero(LSB)][V];
if (llvm::none_of(AV, [&](BranchInst *Elem) { return Elem == BI; }))
AV.push_back(BI);
Underlying -= LSB;
}
}
}
71 changes: 38 additions & 33 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,

if (Q.DC && Q.DT) {
// Handle dominating conditions.
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
for (BranchInst *BI : Q.DC->conditionsFor(V, DomConditionFlag::KnownBits)) {
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
Expand Down Expand Up @@ -2299,7 +2299,8 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,

// Handle dominating conditions.
if (Q.DC && Q.CxtI && Q.DT) {
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
for (BranchInst *BI :
Q.DC->conditionsFor(V, DomConditionFlag::PowerOfTwo)) {
Value *Cond = BI->getCondition();

BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
Expand Down Expand Up @@ -4930,7 +4931,8 @@ static KnownFPClass computeKnownFPClassFromContext(const Value *V,

if (Q.DC && Q.DT) {
// Handle dominating conditions.
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
for (BranchInst *BI :
Q.DC->conditionsFor(V, DomConditionFlag::KnownFPClass)) {
Value *Cond = BI->getCondition();

BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
Expand Down Expand Up @@ -10014,36 +10016,38 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
return CR;
}

static void
addValueAffectedByCondition(Value *V,
function_ref<void(Value *)> InsertAffected) {
static void addValueAffectedByCondition(
Value *V, function_ref<void(Value *, DomConditionFlag)> InsertAffected,
DomConditionFlag Flags) {
assert(V != nullptr);
if (isa<Argument>(V) || isa<GlobalValue>(V)) {
InsertAffected(V);
InsertAffected(V, Flags);
} else if (auto *I = dyn_cast<Instruction>(V)) {
InsertAffected(V);
InsertAffected(V, Flags);

// Peek through unary operators to find the source of the condition.
Value *Op;
if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
if (isa<Instruction>(Op) || isa<Argument>(Op))
InsertAffected(Op);
InsertAffected(Op, Flags);
}
}
}

void llvm::findValuesAffectedByCondition(
Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
auto AddAffected = [&InsertAffected](Value *V) {
addValueAffectedByCondition(V, InsertAffected);
Value *Cond, bool IsAssume,
function_ref<void(Value *, DomConditionFlag)> InsertAffected) {
auto AddAffected = [&InsertAffected](Value *V, DomConditionFlag Flags) {
addValueAffectedByCondition(V, InsertAffected, Flags);
};

auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS,
DomConditionFlag Flags) {
if (IsAssume) {
AddAffected(LHS);
AddAffected(RHS);
AddAffected(LHS, Flags);
AddAffected(RHS, Flags);
} else if (match(RHS, m_Constant()))
AddAffected(LHS);
AddAffected(LHS, Flags);
};

SmallVector<Value *, 8> Worklist;
Expand All @@ -10058,9 +10062,9 @@ void llvm::findValuesAffectedByCondition(
Value *A, *B, *X;

if (IsAssume) {
AddAffected(V);
AddAffected(V, DomConditionFlag::KnownBits);
if (match(V, m_Not(m_Value(X))))
AddAffected(X);
AddAffected(X, DomConditionFlag::KnownBits);
}

if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
Expand All @@ -10074,7 +10078,8 @@ void llvm::findValuesAffectedByCondition(
Worklist.push_back(B);
}
} else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
AddCmpOperands(A, B);
AddCmpOperands(A, B,
DomConditionFlag::KnownBits | DomConditionFlag::ICmp);

bool HasRHSC = match(B, m_ConstantInt());
if (ICmpInst::isEquality(Pred)) {
Expand All @@ -10084,19 +10089,19 @@ void llvm::findValuesAffectedByCondition(
// (X << C) or (X >>_s C) or (X >>_u C).
if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
match(A, m_Shift(m_Value(X), m_ConstantInt())))
AddAffected(X);
AddAffected(X, DomConditionFlag::KnownBits);
else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
match(A, m_Or(m_Value(X), m_Value(Y)))) {
AddAffected(X);
AddAffected(Y);
AddAffected(X, DomConditionFlag::KnownBits);
AddAffected(Y, DomConditionFlag::KnownBits);
}
}
} else {
if (HasRHSC) {
// Handle (A + C1) u< C2, which is the canonical form of
// A > C3 && A < C4.
if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
AddAffected(X);
AddAffected(X, DomConditionFlag::KnownBits);

if (ICmpInst::isUnsigned(Pred)) {
Value *Y;
Expand All @@ -10106,42 +10111,42 @@ void llvm::findValuesAffectedByCondition(
if (match(A, m_And(m_Value(X), m_Value(Y))) ||
match(A, m_Or(m_Value(X), m_Value(Y))) ||
match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
AddAffected(X);
AddAffected(Y);
AddAffected(X, DomConditionFlag::KnownBits);
AddAffected(Y, DomConditionFlag::KnownBits);
}
// X nuw- Y u> C -> X u> C
if (match(A, m_NUWSub(m_Value(X), m_Value())))
AddAffected(X);
AddAffected(X, DomConditionFlag::KnownBits);
}
}

// Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
// by computeKnownFPClass().
if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
InsertAffected(X);
InsertAffected(X, DomConditionFlag::KnownFPClass);
else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
InsertAffected(X);
InsertAffected(X, DomConditionFlag::KnownFPClass);
}
}

if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
AddAffected(X);
AddAffected(X, DomConditionFlag::PowerOfTwo);
} else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
AddCmpOperands(A, B);
AddCmpOperands(A, B, DomConditionFlag::KnownFPClass);

// fcmp fneg(x), y
// fcmp fabs(x), y
// fcmp fneg(fabs(x)), y
if (match(A, m_FNeg(m_Value(A))))
AddAffected(A);
AddAffected(A, DomConditionFlag::KnownFPClass);
if (match(A, m_FAbs(m_Value(A))))
AddAffected(A);
AddAffected(A, DomConditionFlag::KnownFPClass);

} else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
m_Value()))) {
// Handle patterns that computeKnownFPClass() support.
AddAffected(A);
AddAffected(A, DomConditionFlag::KnownFPClass);
}
}
}
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
return nullptr;
};

for (BranchInst *BI : DC.conditionsFor(X)) {
for (BranchInst *BI : DC.conditionsFor(X, DomConditionFlag::ICmp)) {
ICmpInst::Predicate DomPred;
const APInt *DomC;
if (!match(BI->getCondition(),
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4293,9 +4293,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
(!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
// Try to simplify select arms based on KnownBits implied by the condition.
CondContext CC(CondVal);
findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
CC.AffectedValues.insert(V);
});
findValuesAffectedByCondition(
CondVal, /*IsAssume=*/false, [&](Value *V, DomConditionFlag Flags) {
if (any(Flags & DomConditionFlag::KnownBits))
CC.AffectedValues.insert(V);
});
SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
if (!CC.AffectedValues.empty()) {
if (!isa<Constant>(TrueVal) &&
Expand Down
Loading