Skip to content

Commit 4589911

Browse files
authored
[SCEVPatternMatch] Extend with more matchers (#138836)
1 parent 2e436b1 commit 4589911

File tree

3 files changed

+62
-25
lines changed

3 files changed

+62
-25
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
2323
}
2424

2525
template <typename Predicate> struct cst_pred_ty : public Predicate {
26+
cst_pred_ty() = default;
27+
cst_pred_ty(uint64_t V) : Predicate(V) {}
2628
bool match(const SCEV *S) const {
2729
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
2830
"no vector types expected from SCEVs");
@@ -58,6 +60,8 @@ template <typename Class> struct class_match {
5860
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
5961
};
6062

63+
inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
64+
6165
template <typename Class> struct bind_ty {
6266
Class *&VR;
6367

@@ -93,6 +97,34 @@ struct specificscev_ty {
9397
/// Match if we have a specific specified SCEV.
9498
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
9599

100+
struct is_specific_cst {
101+
uint64_t CV;
102+
is_specific_cst(uint64_t C) : CV(C) {}
103+
bool isValue(const APInt &C) const { return C == CV; }
104+
};
105+
106+
/// Match an SCEV constant with a plain unsigned integer.
107+
inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
108+
109+
struct bind_cst_ty {
110+
const APInt *&CR;
111+
112+
bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
113+
114+
bool match(const SCEV *S) const {
115+
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
116+
"no vector types expected from SCEVs");
117+
auto *C = dyn_cast<SCEVConstant>(S);
118+
if (!C)
119+
return false;
120+
CR = &C->getAPInt();
121+
return true;
122+
}
123+
};
124+
125+
/// Match an SCEV constant and bind it to an APInt.
126+
inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }
127+
96128
/// Match a unary SCEV.
97129
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
98130
Op0_t Op0;
@@ -149,6 +181,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
149181
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
150182
}
151183

184+
template <typename Op0_t, typename Op1_t>
185+
inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
186+
m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
187+
return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
188+
}
189+
190+
template <typename Op0_t, typename Op1_t>
191+
inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
192+
m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
193+
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
194+
}
152195
} // namespace SCEVPatternMatch
153196
} // namespace llvm
154197

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
3131
#include "llvm/Analysis/ScalarEvolution.h"
3232
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
33+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
3334
#include "llvm/Analysis/TargetLibraryInfo.h"
3435
#include "llvm/Analysis/TargetTransformInfo.h"
3536
#include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
6566
#include <vector>
6667

6768
using namespace llvm;
69+
using namespace llvm::SCEVPatternMatch;
6870

6971
#define DEBUG_TYPE "loop-accesses"
7072

@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
811813
const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
812814

813815
// Calculate the pointer stride and check if it is constant.
814-
const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
815-
if (!C) {
816+
const APInt *APStepVal;
817+
if (!match(Step, m_scev_APInt(APStepVal))) {
816818
LLVM_DEBUG({
817819
dbgs() << "LAA: Bad stride - Not a constant strided ";
818820
if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
825827
const auto &DL = Lp->getHeader()->getDataLayout();
826828
TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
827829
int64_t Size = AllocSize.getFixedValue();
828-
const APInt &APStepVal = C->getAPInt();
829830

830831
// Huge step value - give up.
831-
if (APStepVal.getBitWidth() > 64)
832+
if (APStepVal->getBitWidth() > 64)
832833
return std::nullopt;
833834

834-
int64_t StepVal = APStepVal.getSExtValue();
835+
int64_t StepVal = APStepVal->getSExtValue();
835836

836837
// Strided access.
837838
int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
20612062
DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
20622063
return Dependence::NoDep;
20632064

2064-
const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
2065-
20662065
// Attempt to prove strided accesses independent.
2067-
if (ConstDist) {
2068-
uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
2066+
const APInt *ConstDist = nullptr;
2067+
if (match(Dist, m_scev_APInt(ConstDist))) {
2068+
uint64_t Distance = ConstDist->abs().getZExtValue();
20692069

20702070
// If the distance between accesses and their strides are known constants,
20712071
// check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
21112111
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
21122112
return Dependence::Unknown;
21132113
}
2114-
if (!HasSameSize ||
2115-
couldPreventStoreLoadForward(
2116-
ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
2114+
if (!HasSameSize || couldPreventStoreLoadForward(
2115+
ConstDist->abs().getZExtValue(), TypeByteSize)) {
21172116
LLVM_DEBUG(
21182117
dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
21192118
return Dependence::ForwardButPreventsForwarding;

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
71497149
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
71507150
"Should be!");
71517151

7152-
// Peel off a constant offset:
7153-
if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7154-
// In the future we could consider being smarter here and handle
7155-
// {Start+Step,+,Step} too.
7156-
if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7157-
return;
7158-
7159-
Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7160-
S = SA->getOperand(1);
7161-
}
7152+
// Peel off a constant offset. In the future we could consider being
7153+
// smarter here and handle {Start+Step,+,Step} too.
7154+
const APInt *Off;
7155+
if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7156+
Offset = *Off;
71627157

71637158
// Peel off a cast operation
71647159
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
73377332

73387333
bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
73397334
return !SCEVExprContains(Op, [this](const SCEV *S) {
7340-
auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7335+
const SCEV *Op1;
7336+
bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
73417337
// The UDiv may be UB if the divisor is poison or zero. Unless the divisor
73427338
// is a non-zero constant, we have to assume the UDiv may be UB.
7343-
return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7344-
!isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7339+
return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
73457340
});
73467341
}
73477342

0 commit comments

Comments
 (0)