Skip to content

Commit be57381

Browse files
authored
[InstCombine] Create a class to lazily track computed known bits (#66611)
This patch adds a new class "WithCache" which stores a pointer to any type passable to computeKnownBits along with KnownBits information which is computed on-demand when getKnownBits() is called. This allows reusing the known bits information when it is passed as an argument to multiple functions. It also changes a few functions to accept WithCache(s) so that known bits information computed in some callees can be propagated to others from the top level visitAddSub caller. This gives a speedup of 0.14%: https://llvm-compile-time-tracker.com/compare.php?from=499d41cef2e7bbb65804f6a815b9fa8b27efce0f&to=fbea87f1f1e6d5552e2bc309f8e201a3af6d28ec&stat=instructions:u
1 parent 7b1e685 commit be57381

File tree

6 files changed

+140
-52
lines changed

6 files changed

+140
-52
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/ADT/ArrayRef.h"
1818
#include "llvm/ADT/SmallSet.h"
1919
#include "llvm/Analysis/SimplifyQuery.h"
20+
#include "llvm/Analysis/WithCache.h"
2021
#include "llvm/IR/Constants.h"
2122
#include "llvm/IR/DataLayout.h"
2223
#include "llvm/IR/FMF.h"
@@ -90,6 +91,12 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
9091
const DominatorTree *DT = nullptr,
9192
bool UseInstrInfo = true);
9293

94+
KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
95+
unsigned Depth, const SimplifyQuery &Q);
96+
97+
KnownBits computeKnownBits(const Value *V, unsigned Depth,
98+
const SimplifyQuery &Q);
99+
93100
/// Compute known bits from the range metadata.
94101
/// \p KnownZero the set of bits that are known to be zero
95102
/// \p KnownOne the set of bits that are known to be one
@@ -107,7 +114,8 @@ KnownBits analyzeKnownBitsFromAndXorOr(
107114
bool UseInstrInfo = true);
108115

109116
/// Return true if LHS and RHS have no common bits set.
110-
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
117+
bool haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
118+
const WithCache<const Value *> &RHSCache,
111119
const SimplifyQuery &SQ);
112120

113121
/// Return true if the given value is known to have exactly one bit set when
@@ -847,9 +855,12 @@ OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
847855
const SimplifyQuery &SQ);
848856
OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
849857
const SimplifyQuery &SQ);
850-
OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
851-
const SimplifyQuery &SQ);
852-
OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
858+
OverflowResult
859+
computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
860+
const WithCache<const Value *> &RHS,
861+
const SimplifyQuery &SQ);
862+
OverflowResult computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
863+
const WithCache<const Value *> &RHS,
853864
const SimplifyQuery &SQ);
854865
/// This version also leverages the sign bit of Add if known.
855866
OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- llvm/Analysis/WithCache.h - KnownBits cache for pointers -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Store a pointer to any type along with the KnownBits information for it
10+
// that is computed lazily (if required).
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_ANALYSIS_WITHCACHE_H
15+
#define LLVM_ANALYSIS_WITHCACHE_H
16+
17+
#include "llvm/IR/Value.h"
18+
#include "llvm/Support/KnownBits.h"
19+
#include <type_traits>
20+
21+
namespace llvm {
22+
struct SimplifyQuery;
23+
KnownBits computeKnownBits(const Value *V, unsigned Depth,
24+
const SimplifyQuery &Q);
25+
26+
template <typename Arg> class WithCache {
27+
static_assert(std::is_pointer_v<Arg>, "WithCache requires a pointer type!");
28+
29+
using UnderlyingType = std::remove_pointer_t<Arg>;
30+
constexpr static bool IsConst = std::is_const_v<Arg>;
31+
32+
template <typename T, bool Const>
33+
using conditionally_const_t = std::conditional_t<Const, const T, T>;
34+
35+
using PointerType = conditionally_const_t<UnderlyingType *, IsConst>;
36+
using ReferenceType = conditionally_const_t<UnderlyingType &, IsConst>;
37+
38+
// Store the presence of the KnownBits information in one of the bits of
39+
// Pointer.
40+
// true -> present
41+
// false -> absent
42+
mutable PointerIntPair<PointerType, 1, bool> Pointer;
43+
mutable KnownBits Known;
44+
45+
void calculateKnownBits(const SimplifyQuery &Q) const {
46+
Known = computeKnownBits(Pointer.getPointer(), 0, Q);
47+
Pointer.setInt(true);
48+
}
49+
50+
public:
51+
WithCache(PointerType Pointer) : Pointer(Pointer, false) {}
52+
WithCache(PointerType Pointer, const KnownBits &Known)
53+
: Pointer(Pointer, true), Known(Known) {}
54+
55+
[[nodiscard]] PointerType getValue() const { return Pointer.getPointer(); }
56+
57+
[[nodiscard]] const KnownBits &getKnownBits(const SimplifyQuery &Q) const {
58+
if (!hasKnownBits())
59+
calculateKnownBits(Q);
60+
return Known;
61+
}
62+
63+
[[nodiscard]] bool hasKnownBits() const { return Pointer.getInt(); }
64+
65+
operator PointerType() const { return Pointer.getPointer(); }
66+
PointerType operator->() const { return Pointer.getPointer(); }
67+
ReferenceType operator*() const { return *Pointer.getPointer(); }
68+
};
69+
} // namespace llvm
70+
71+
#endif

llvm/include/llvm/Transforms/InstCombine/InstCombiner.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,15 +510,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
510510
SQ.getWithInstruction(CxtI));
511511
}
512512

513-
OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
514-
const Value *RHS,
515-
const Instruction *CxtI) const {
513+
OverflowResult
514+
computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
515+
const WithCache<const Value *> &RHS,
516+
const Instruction *CxtI) const {
516517
return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
517518
SQ.getWithInstruction(CxtI));
518519
}
519520

520-
OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
521-
const Instruction *CxtI) const {
521+
OverflowResult
522+
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
523+
const WithCache<const Value *> &RHS,
524+
const Instruction *CxtI) const {
522525
return llvm::computeOverflowForSignedAdd(LHS, RHS,
523526
SQ.getWithInstruction(CxtI));
524527
}

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
3434
#include "llvm/Analysis/TargetLibraryInfo.h"
3535
#include "llvm/Analysis/VectorUtils.h"
36+
#include "llvm/Analysis/WithCache.h"
3637
#include "llvm/IR/Argument.h"
3738
#include "llvm/IR/Attributes.h"
3839
#include "llvm/IR/BasicBlock.h"
@@ -178,31 +179,29 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
178179
SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
179180
}
180181

181-
static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
182-
unsigned Depth, const SimplifyQuery &Q);
183-
184-
static KnownBits computeKnownBits(const Value *V, unsigned Depth,
185-
const SimplifyQuery &Q);
186-
187182
KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
188183
unsigned Depth, AssumptionCache *AC,
189184
const Instruction *CxtI,
190185
const DominatorTree *DT, bool UseInstrInfo) {
191-
return ::computeKnownBits(
186+
return computeKnownBits(
192187
V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
193188
}
194189

195190
KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
196191
const DataLayout &DL, unsigned Depth,
197192
AssumptionCache *AC, const Instruction *CxtI,
198193
const DominatorTree *DT, bool UseInstrInfo) {
199-
return ::computeKnownBits(
194+
return computeKnownBits(
200195
V, DemandedElts, Depth,
201196
SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
202197
}
203198

204-
bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
199+
bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
200+
const WithCache<const Value *> &RHSCache,
205201
const SimplifyQuery &SQ) {
202+
const Value *LHS = LHSCache.getValue();
203+
const Value *RHS = RHSCache.getValue();
204+
206205
assert(LHS->getType() == RHS->getType() &&
207206
"LHS and RHS should have the same type");
208207
assert(LHS->getType()->isIntOrIntVectorTy() &&
@@ -250,12 +249,9 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
250249
match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
251250
return true;
252251
}
253-
IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
254-
KnownBits LHSKnown(IT->getBitWidth());
255-
KnownBits RHSKnown(IT->getBitWidth());
256-
::computeKnownBits(LHS, LHSKnown, 0, SQ);
257-
::computeKnownBits(RHS, RHSKnown, 0, SQ);
258-
return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown);
252+
253+
return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
254+
RHSCache.getKnownBits(SQ));
259255
}
260256

261257
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
@@ -1784,19 +1780,19 @@ static void computeKnownBitsFromOperator(const Operator *I,
17841780

17851781
/// Determine which bits of V are known to be either zero or one and return
17861782
/// them.
1787-
KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
1788-
unsigned Depth, const SimplifyQuery &Q) {
1783+
KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
1784+
unsigned Depth, const SimplifyQuery &Q) {
17891785
KnownBits Known(getBitWidth(V->getType(), Q.DL));
1790-
computeKnownBits(V, DemandedElts, Known, Depth, Q);
1786+
::computeKnownBits(V, DemandedElts, Known, Depth, Q);
17911787
return Known;
17921788
}
17931789

17941790
/// Determine which bits of V are known to be either zero or one and return
17951791
/// them.
1796-
KnownBits computeKnownBits(const Value *V, unsigned Depth,
1797-
const SimplifyQuery &Q) {
1792+
KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth,
1793+
const SimplifyQuery &Q) {
17981794
KnownBits Known(getBitWidth(V->getType(), Q.DL));
1799-
computeKnownBits(V, Known, Depth, Q);
1795+
::computeKnownBits(V, Known, Depth, Q);
18001796
return Known;
18011797
}
18021798

@@ -6256,10 +6252,11 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
62566252

62576253
/// Combine constant ranges from computeConstantRange() and computeKnownBits().
62586254
static ConstantRange
6259-
computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
6255+
computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
6256+
bool ForSigned,
62606257
const SimplifyQuery &SQ) {
6261-
KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
6262-
ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
6258+
ConstantRange CR1 =
6259+
ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
62636260
ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
62646261
ConstantRange::PreferredRangeType RangeType =
62656262
ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
@@ -6269,8 +6266,8 @@ computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
62696266
OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
62706267
const Value *RHS,
62716268
const SimplifyQuery &SQ) {
6272-
KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
6273-
KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
6269+
KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
6270+
KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
62746271
ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
62756272
ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
62766273
return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
@@ -6307,28 +6304,29 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
63076304
// product is exactly the minimum negative number.
63086305
// E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
63096306
// For simplicity we just check if at least one side is not negative.
6310-
KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
6311-
KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
6307+
KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
6308+
KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
63126309
if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
63136310
return OverflowResult::NeverOverflows;
63146311
}
63156312
return OverflowResult::MayOverflow;
63166313
}
63176314

6318-
OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
6319-
const Value *RHS,
6320-
const SimplifyQuery &SQ) {
6315+
OverflowResult
6316+
llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
6317+
const WithCache<const Value *> &RHS,
6318+
const SimplifyQuery &SQ) {
63216319
ConstantRange LHSRange =
63226320
computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
63236321
ConstantRange RHSRange =
63246322
computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
63256323
return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
63266324
}
63276325

6328-
static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
6329-
const Value *RHS,
6330-
const AddOperator *Add,
6331-
const SimplifyQuery &SQ) {
6326+
static OverflowResult
6327+
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
6328+
const WithCache<const Value *> &RHS,
6329+
const AddOperator *Add, const SimplifyQuery &SQ) {
63326330
if (Add && Add->hasNoSignedWrap()) {
63336331
return OverflowResult::NeverOverflows;
63346332
}
@@ -6944,9 +6942,10 @@ OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
69446942
Add, SQ);
69456943
}
69466944

6947-
OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
6948-
const Value *RHS,
6949-
const SimplifyQuery &SQ) {
6945+
OverflowResult
6946+
llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
6947+
const WithCache<const Value *> &RHS,
6948+
const SimplifyQuery &SQ) {
69506949
return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
69516950
}
69526951

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,7 +1566,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15661566
return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
15671567

15681568
// A+B --> A|B iff A and B have no bits set in common.
1569-
if (haveNoCommonBitsSet(LHS, RHS, SQ.getWithInstruction(&I)))
1569+
WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
1570+
if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
15701571
return BinaryOperator::CreateOr(LHS, RHS);
15711572

15721573
if (Instruction *Ext = narrowMathIfNoOverflow(I))
@@ -1661,11 +1662,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
16611662
// willNotOverflowUnsignedAdd to reduce the number of invocations of
16621663
// computeKnownBits.
16631664
bool Changed = false;
1664-
if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
1665+
if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
16651666
Changed = true;
16661667
I.setHasNoSignedWrap(true);
16671668
}
1668-
if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
1669+
if (!I.hasNoUnsignedWrap() &&
1670+
willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
16691671
Changed = true;
16701672
I.setHasNoUnsignedWrap(true);
16711673
}

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
295295

296296
Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext);
297297

298-
bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
298+
bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS,
299+
const WithCache<const Value *> &RHS,
299300
const Instruction &CxtI) const {
300301
return computeOverflowForSignedAdd(LHS, RHS, &CxtI) ==
301302
OverflowResult::NeverOverflows;
302303
}
303304

304-
bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS,
305+
bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS,
306+
const WithCache<const Value *> &RHS,
305307
const Instruction &CxtI) const {
306308
return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) ==
307309
OverflowResult::NeverOverflows;

0 commit comments

Comments
 (0)