Skip to content

[InstCombine] Create a class to lazily track computed known bits #66611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/FMF.h"
Expand Down Expand Up @@ -90,6 +91,12 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
const DominatorTree *DT = nullptr,
bool UseInstrInfo = true);

KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const SimplifyQuery &Q);

KnownBits computeKnownBits(const Value *V, unsigned Depth,
const SimplifyQuery &Q);

/// Compute known bits from the range metadata.
/// \p KnownZero the set of bits that are known to be zero
/// \p KnownOne the set of bits that are known to be one
Expand All @@ -107,7 +114,8 @@ KnownBits analyzeKnownBitsFromAndXorOr(
bool UseInstrInfo = true);

/// Return true if LHS and RHS have no common bits set.
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
bool haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
const WithCache<const Value *> &RHSCache,
const SimplifyQuery &SQ);

/// Return true if the given value is known to have exactly one bit set when
Expand Down Expand Up @@ -847,9 +855,12 @@ OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
const SimplifyQuery &SQ);
OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
const SimplifyQuery &SQ);
OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
const SimplifyQuery &SQ);
OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
OverflowResult
computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const SimplifyQuery &SQ);
OverflowResult computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const SimplifyQuery &SQ);
/// This version also leverages the sign bit of Add if known.
OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,
Expand Down
71 changes: 71 additions & 0 deletions llvm/include/llvm/Analysis/WithCache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- llvm/Analysis/WithCache.h - KnownBits cache for pointers -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Store a pointer to any type along with the KnownBits information for it
// that is computed lazily (if required).
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_WITHCACHE_H
#define LLVM_ANALYSIS_WITHCACHE_H

#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"
#include <type_traits>

namespace llvm {
struct SimplifyQuery;
KnownBits computeKnownBits(const Value *V, unsigned Depth,
const SimplifyQuery &Q);

/// Bundles a pointer with the KnownBits of the value it points to.
///
/// The KnownBits are either supplied up front (two-argument constructor) or
/// computed on the first call to getKnownBits() and remembered afterwards, so
/// that one WithCache object handed to several queries triggers at most one
/// computeKnownBits() call.
///
/// NOTE(review): PointerIntPair is used below, but this header does not
/// include "llvm/ADT/PointerIntPair.h" directly; presumably it is reached
/// through another include -- confirm, or include it explicitly.
template <typename Arg> class WithCache {
  static_assert(std::is_pointer_v<Arg>, "WithCache requires a pointer type!");

  using UnderlyingType = std::remove_pointer_t<Arg>;
  // Note: this inspects the constness of the pointer type Arg itself.
  constexpr static bool IsConst = std::is_const_v<Arg>;

  template <typename T, bool Const>
  using conditionally_const_t = std::conditional_t<Const, const T, T>;

  using PointerType = conditionally_const_t<UnderlyingType *, IsConst>;
  using ReferenceType = conditionally_const_t<UnderlyingType &, IsConst>;

  // The integer bit of the pair records whether Known has been filled in:
  //   true  -> Known holds the cached KnownBits
  //   false -> nothing has been cached yet
  // Both members are mutable so that the const accessor getKnownBits() can
  // populate the cache on demand.
  mutable PointerIntPair<PointerType, 1, bool> Pointer;
  mutable KnownBits Known;

public:
  /// Wrap \p Pointer without any KnownBits; they are computed on demand.
  WithCache(PointerType Pointer) : Pointer(Pointer, false) {}

  /// Wrap \p Pointer together with already available \p Known bits, which are
  /// stored and marked as present.
  WithCache(PointerType Pointer, const KnownBits &Known)
      : Pointer(Pointer, true), Known(Known) {}

  /// Return the wrapped pointer.
  [[nodiscard]] PointerType getValue() const { return Pointer.getPointer(); }

  /// Return true if KnownBits are already stored in this object.
  [[nodiscard]] bool hasKnownBits() const { return Pointer.getInt(); }

  /// Return the KnownBits for the wrapped pointer. If none are stored yet,
  /// they are computed via computeKnownBits() at depth 0 using \p Q, stored,
  /// and flagged as present before being returned.
  [[nodiscard]] const KnownBits &getKnownBits(const SimplifyQuery &Q) const {
    if (hasKnownBits())
      return Known;

    Known = computeKnownBits(getValue(), 0, Q);
    Pointer.setInt(true);
    return Known;
  }

  // Pointer-like access, so a WithCache can be used where the raw pointer is
  // expected.
  operator PointerType() const { return getValue(); }
  PointerType operator->() const { return getValue(); }
  ReferenceType operator*() const { return *getValue(); }
};
} // namespace llvm

#endif
13 changes: 8 additions & 5 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
SQ.getWithInstruction(CxtI));
}

OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
const Value *RHS,
const Instruction *CxtI) const {
OverflowResult
computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const Instruction *CxtI) const {
return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
SQ.getWithInstruction(CxtI));
}

OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
const Instruction *CxtI) const {
OverflowResult
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const Instruction *CxtI) const {
return llvm::computeOverflowForSignedAdd(LHS, RHS,
SQ.getWithInstruction(CxtI));
}
Expand Down
75 changes: 37 additions & 38 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
Expand Down Expand Up @@ -178,31 +179,29 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const SimplifyQuery &Q);

static KnownBits computeKnownBits(const Value *V, unsigned Depth,
const SimplifyQuery &Q);

KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
unsigned Depth, AssumptionCache *AC,
const Instruction *CxtI,
const DominatorTree *DT, bool UseInstrInfo) {
return ::computeKnownBits(
return computeKnownBits(
V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
const DataLayout &DL, unsigned Depth,
AssumptionCache *AC, const Instruction *CxtI,
const DominatorTree *DT, bool UseInstrInfo) {
return ::computeKnownBits(
return computeKnownBits(
V, DemandedElts, Depth,
SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
}

bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
const WithCache<const Value *> &RHSCache,
const SimplifyQuery &SQ) {
const Value *LHS = LHSCache.getValue();
const Value *RHS = RHSCache.getValue();

assert(LHS->getType() == RHS->getType() &&
"LHS and RHS should have the same type");
assert(LHS->getType()->isIntOrIntVectorTy() &&
Expand Down Expand Up @@ -250,12 +249,9 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return true;
}
IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
KnownBits LHSKnown(IT->getBitWidth());
KnownBits RHSKnown(IT->getBitWidth());
::computeKnownBits(LHS, LHSKnown, 0, SQ);
::computeKnownBits(RHS, RHSKnown, 0, SQ);
return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown);

return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
RHSCache.getKnownBits(SQ));
}

bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
Expand Down Expand Up @@ -1784,19 +1780,19 @@ static void computeKnownBitsFromOperator(const Operator *I,

/// Determine which bits of V are known to be either zero or one and return
/// them.
KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const SimplifyQuery &Q) {
KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const SimplifyQuery &Q) {
KnownBits Known(getBitWidth(V->getType(), Q.DL));
computeKnownBits(V, DemandedElts, Known, Depth, Q);
::computeKnownBits(V, DemandedElts, Known, Depth, Q);
return Known;
}

/// Determine which bits of V are known to be either zero or one and return
/// them.
KnownBits computeKnownBits(const Value *V, unsigned Depth,
const SimplifyQuery &Q) {
KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth,
const SimplifyQuery &Q) {
KnownBits Known(getBitWidth(V->getType(), Q.DL));
computeKnownBits(V, Known, Depth, Q);
::computeKnownBits(V, Known, Depth, Q);
return Known;
}

Expand Down Expand Up @@ -6256,10 +6252,11 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {

/// Combine constant ranges from computeConstantRange() and computeKnownBits().
static ConstantRange
computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
bool ForSigned,
const SimplifyQuery &SQ) {
KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
ConstantRange CR1 =
ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would still personally be in favor of making getCR an API of its own. But not particularly important.

Copy link
Contributor

@nikic nikic Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add a getConstantRange API on this class, then what it would have to return is a cached version of computeConstantRange(), not ConstantRange::fromKnownBits(). That's something we could evaluate, but not as part of this PR. I don't think it would be worthwhile right now (because, while we do repeated computeConstantRange() calls, they use different signedness, so couldn't be reused).

ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
ConstantRange::PreferredRangeType RangeType =
ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
Expand All @@ -6269,8 +6266,8 @@ computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
Expand Down Expand Up @@ -6307,28 +6304,29 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
// product is exactly the minimum negative number.
// E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
// For simplicity we just check if at least one side is not negative.
KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
return OverflowResult::NeverOverflows;
}
return OverflowResult::MayOverflow;
}

OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
OverflowResult
llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const SimplifyQuery &SQ) {
ConstantRange LHSRange =
computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
ConstantRange RHSRange =
computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
}

static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
const Value *RHS,
const AddOperator *Add,
const SimplifyQuery &SQ) {
static OverflowResult
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const AddOperator *Add, const SimplifyQuery &SQ) {
if (Add && Add->hasNoSignedWrap()) {
return OverflowResult::NeverOverflows;
}
Expand Down Expand Up @@ -6944,9 +6942,10 @@ OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
Add, SQ);
}

OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
OverflowResult
llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const SimplifyQuery &SQ) {
return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
}

Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));

// A+B --> A|B iff A and B have no bits set in common.
if (haveNoCommonBitsSet(LHS, RHS, SQ.getWithInstruction(&I)))
WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
return BinaryOperator::CreateOr(LHS, RHS);

if (Instruction *Ext = narrowMathIfNoOverflow(I))
Expand Down Expand Up @@ -1661,11 +1662,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
bool Changed = false;
if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoSignedWrap(true);
}
if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
if (!I.hasNoUnsignedWrap() &&
willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoUnsignedWrap(true);
}
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final

Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext);

bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const Instruction &CxtI) const {
return computeOverflowForSignedAdd(LHS, RHS, &CxtI) ==
OverflowResult::NeverOverflows;
}

bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS,
bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const Instruction &CxtI) const {
return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) ==
OverflowResult::NeverOverflows;
Expand Down