Skip to content

Commit eecf6d1

Browse files
committed
[SCEVPatternMatch] Extend with more matchers
1 parent 52b345d commit eecf6d1

File tree

3 files changed

+69
-39
lines changed

3 files changed

+69
-39
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ template <typename Class> struct class_match {
5858
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
5959
};
6060

61+
inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
62+
6163
template <typename Class> struct bind_ty {
6264
Class *&VR;
6365

@@ -93,6 +95,41 @@ struct specificscev_ty {
9395
/// Match if we have a specific specified SCEV.
9496
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
9597

98+
template <typename Class> struct cst_match {
99+
Class CV;
100+
101+
cst_match(Class Op0) : CV(Op0) {}
102+
103+
bool match(const SCEV *S) const {
104+
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
105+
"no vector types expected from SCEVs");
106+
auto *C = dyn_cast<SCEVConstant>(S);
107+
return C && C->getAPInt() == CV;
108+
}
109+
};
110+
111+
/// Match an SCEV constant with a plain unsigned integer.
112+
inline cst_match<uint64_t> m_SCEVConstant(uint64_t V) { return V; }
113+
114+
struct bind_cst_ty {
115+
const APInt *&CR;
116+
117+
bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
118+
119+
bool match(const SCEV *S) const {
120+
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
121+
"no vector types expected from SCEVs");
122+
auto *C = dyn_cast<SCEVConstant>(S);
123+
if (!C)
124+
return false;
125+
CR = &C->getAPInt();
126+
return true;
127+
}
128+
};
129+
130+
/// Match an SCEV constant and bind it to an APInt.
131+
inline bind_cst_ty m_SCEVConstant(const APInt *&C) { return C; }
132+
96133
/// Match a unary SCEV.
97134
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
98135
Op0_t Op0;
@@ -149,6 +186,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
149186
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
150187
}
151188

189+
template <typename Op0_t, typename Op1_t>
190+
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
191+
m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
192+
return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
193+
}
194+
195+
template <typename Op0_t, typename Op1_t>
196+
inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
197+
m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
198+
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
199+
}
152200
} // namespace SCEVPatternMatch
153201
} // namespace llvm
154202

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
3131
#include "llvm/Analysis/ScalarEvolution.h"
3232
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
33+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
3334
#include "llvm/Analysis/TargetLibraryInfo.h"
3435
#include "llvm/Analysis/TargetTransformInfo.h"
3536
#include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
6566
#include <vector>
6667

6768
using namespace llvm;
69+
using namespace llvm::SCEVPatternMatch;
6870

6971
#define DEBUG_TYPE "loop-accesses"
7072

@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
811813
const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
812814

813815
// Calculate the pointer stride and check if it is constant.
814-
const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
815-
if (!C) {
816+
const APInt *APStepVal;
817+
if (!match(Step, m_SCEVConstant(APStepVal))) {
816818
LLVM_DEBUG({
817819
dbgs() << "LAA: Bad stride - Not a constant strided ";
818820
if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
825827
const auto &DL = Lp->getHeader()->getDataLayout();
826828
TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
827829
int64_t Size = AllocSize.getFixedValue();
828-
const APInt &APStepVal = C->getAPInt();
829830

830831
// Huge step value - give up.
831-
if (APStepVal.getBitWidth() > 64)
832+
if (APStepVal->getBitWidth() > 64)
832833
return std::nullopt;
833834

834-
int64_t StepVal = APStepVal.getSExtValue();
835+
int64_t StepVal = APStepVal->getSExtValue();
835836

836837
// Strided access.
837838
int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
20612062
DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
20622063
return Dependence::NoDep;
20632064

2064-
const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
2065-
20662065
// Attempt to prove strided accesses independent.
2067-
if (ConstDist) {
2068-
uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
2066+
const APInt *ConstDist = nullptr;
2067+
if (match(Dist, m_SCEVConstant(ConstDist))) {
2068+
uint64_t Distance = ConstDist->abs().getZExtValue();
20692069

20702070
// If the distance between accesses and their strides are known constants,
20712071
// check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
21112111
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
21122112
return Dependence::Unknown;
21132113
}
2114-
if (!HasSameSize ||
2115-
couldPreventStoreLoadForward(
2116-
ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
2114+
if (!HasSameSize || couldPreventStoreLoadForward(
2115+
ConstDist->abs().getZExtValue(), TypeByteSize)) {
21172116
LLVM_DEBUG(
21182117
dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
21192118
return Dependence::ForwardButPreventsForwarding;
@@ -2864,20 +2863,8 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
28642863

28652864
// Strip off the size of access multiplication if we are still analyzing the
28662865
// pointer.
2867-
if (OrigPtr == Ptr) {
2868-
if (auto *M = dyn_cast<SCEVMulExpr>(V)) {
2869-
auto *StepConst = dyn_cast<SCEVConstant>(M->getOperand(0));
2870-
if (!StepConst)
2871-
return nullptr;
2872-
2873-
auto StepVal = StepConst->getAPInt().trySExtValue();
2874-
// Bail out on a non-unit pointer access size.
2875-
if (!StepVal || StepVal != 1)
2876-
return nullptr;
2877-
2878-
V = M->getOperand(1);
2879-
}
2880-
}
2866+
if (OrigPtr == Ptr)
2867+
match(V, m_scev_Mul(m_SCEVConstant(1), m_SCEV(V)));
28812868

28822869
// Note that the restriction after this loop invariant check are only
28832870
// profitability restrictions.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
71497149
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
71507150
"Should be!");
71517151

7152-
// Peel off a constant offset:
7153-
if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7154-
// In the future we could consider being smarter here and handle
7155-
// {Start+Step,+,Step} too.
7156-
if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7157-
return;
7158-
7159-
Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7160-
S = SA->getOperand(1);
7161-
}
7152+
// Peel off a constant offset. In the future we could consider being
7153+
// smarter here and handle {Start+Step,+,Step} too.
7154+
const APInt *Off;
7155+
if (match(S, m_scev_Add(m_SCEVConstant(Off), m_SCEV(S))))
7156+
Offset = *Off;
71627157

71637158
// Peel off a cast operation
71647159
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
73377332

73387333
bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
73397334
return !SCEVExprContains(Op, [this](const SCEV *S) {
7340-
auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7335+
const SCEV *Op1;
7336+
bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
73417337
// The UDiv may be UB if the divisor is poison or zero. Unless the divisor
73427338
// is a non-zero constant, we have to assume the UDiv may be UB.
7343-
return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7344-
!isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7339+
return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
73457340
});
73467341
}
73477342

0 commit comments

Comments
 (0)