Skip to content

Commit 8ea9576

Browse files
authored
[SCEV] Add initial matchers for SCEV expressions. (NFC) (llvm#119390)
This patch adds initial matchers for unary and binary SCEV expressions and specializes it for SExt, ZExt and binary add expressions. Also adds matchers for SCEVConstant and SCEVUnknown. This patch only converts a few instances to use the new matchers to make sure everything builds as expected for now. The goal of the matchers is to hopefully make it slightly easier to write code matching SCEV patterns. Depends on llvm#119389 PR: llvm#119390
1 parent c1f5937 commit 8ea9576

File tree

2 files changed

+109
-21
lines changed

2 files changed

+109
-21
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,101 @@ inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
5252
return cst_pred_ty<is_all_ones>();
5353
}
5454

55+
template <typename Class> struct class_match {
56+
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
57+
};
58+
59+
template <typename Class> struct bind_ty {
60+
Class *&VR;
61+
62+
bind_ty(Class *&V) : VR(V) {}
63+
64+
template <typename ITy> bool match(ITy *V) const {
65+
if (auto *CV = dyn_cast<Class>(V)) {
66+
VR = CV;
67+
return true;
68+
}
69+
return false;
70+
}
71+
};
72+
73+
/// Match a SCEV, capturing it if we match.
74+
inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
75+
inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
76+
return V;
77+
}
78+
inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
79+
return V;
80+
}
81+
82+
/// Match a specified const SCEV *.
83+
struct specificscev_ty {
84+
const SCEV *Expr;
85+
86+
specificscev_ty(const SCEV *Expr) : Expr(Expr) {}
87+
88+
template <typename ITy> bool match(ITy *S) { return S == Expr; }
89+
};
90+
91+
/// Match if we have a specific specified SCEV.
92+
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
93+
94+
/// Match a unary SCEV.
95+
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
96+
Op0_t Op0;
97+
98+
SCEVUnaryExpr_match(Op0_t Op0) : Op0(Op0) {}
99+
100+
bool match(const SCEV *S) {
101+
auto *E = dyn_cast<SCEVTy>(S);
102+
return E && E->getNumOperands() == 1 && Op0.match(E->getOperand(0));
103+
}
104+
};
105+
106+
template <typename SCEVTy, typename Op0_t>
107+
inline SCEVUnaryExpr_match<SCEVTy, Op0_t> m_scev_Unary(const Op0_t &Op0) {
108+
return SCEVUnaryExpr_match<SCEVTy, Op0_t>(Op0);
109+
}
110+
111+
template <typename Op0_t>
112+
inline SCEVUnaryExpr_match<SCEVSignExtendExpr, Op0_t>
113+
m_scev_SExt(const Op0_t &Op0) {
114+
return m_scev_Unary<SCEVSignExtendExpr>(Op0);
115+
}
116+
117+
template <typename Op0_t>
118+
inline SCEVUnaryExpr_match<SCEVZeroExtendExpr, Op0_t>
119+
m_scev_ZExt(const Op0_t &Op0) {
120+
return m_scev_Unary<SCEVZeroExtendExpr>(Op0);
121+
}
122+
123+
/// Match a binary SCEV.
124+
template <typename SCEVTy, typename Op0_t, typename Op1_t>
125+
struct SCEVBinaryExpr_match {
126+
Op0_t Op0;
127+
Op1_t Op1;
128+
129+
SCEVBinaryExpr_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
130+
131+
bool match(const SCEV *S) {
132+
auto *E = dyn_cast<SCEVTy>(S);
133+
return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
134+
Op1.match(E->getOperand(1));
135+
}
136+
};
137+
138+
template <typename SCEVTy, typename Op0_t, typename Op1_t>
139+
inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
140+
m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
141+
return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
142+
}
143+
144+
template <typename Op0_t, typename Op1_t>
145+
inline SCEVBinaryExpr_match<SCEVAddExpr, Op0_t, Op1_t>
146+
m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
147+
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
148+
}
149+
55150
} // namespace SCEVPatternMatch
56151
} // namespace llvm
57152

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12725,33 +12725,28 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
1272512725
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
1272612726
const SCEV *LHS, const SCEV *RHS) {
1272712727
// zext x u<= sext x, sext x s<= zext x
12728+
const SCEV *Op;
1272812729
switch (Pred) {
1272912730
case ICmpInst::ICMP_SGE:
1273012731
std::swap(LHS, RHS);
1273112732
[[fallthrough]];
1273212733
case ICmpInst::ICMP_SLE: {
12733-
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12734-
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12735-
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12736-
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12737-
return true;
12738-
break;
12734+
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12735+
return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12736+
match(RHS, m_scev_ZExt(m_Specific(Op)));
1273912737
}
1274012738
case ICmpInst::ICMP_UGE:
1274112739
std::swap(LHS, RHS);
1274212740
[[fallthrough]];
1274312741
case ICmpInst::ICMP_ULE: {
12744-
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12745-
const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12746-
const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12747-
if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12748-
return true;
12749-
break;
12742+
// If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12743+
return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12744+
match(RHS, m_scev_SExt(m_Specific(Op)));
1275012745
}
1275112746
default:
12752-
break;
12747+
return false;
1275312748
};
12754-
return false;
12749+
llvm_unreachable("unhandled case");
1275512750
}
1275612751

1275712752
bool
@@ -15417,14 +15412,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1541715412
// (X >=u C1).
1541815413
auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
1541915414
&ExprsToRewrite]() {
15420-
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15421-
if (!AddExpr || AddExpr->getNumOperands() != 2)
15422-
return false;
15423-
15424-
auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15425-
auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15415+
const SCEVConstant *C1;
15416+
const SCEVUnknown *LHSUnknown;
1542615417
auto *C2 = dyn_cast<SCEVConstant>(RHS);
15427-
if (!C1 || !C2 || !LHSUnknown)
15418+
if (!match(LHS,
15419+
m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15420+
!C2)
1542815421
return false;
1542915422

1543015423
auto ExactRegion =

0 commit comments

Comments
 (0)