79
79
#include "llvm/Analysis/LoopInfo.h"
80
80
#include "llvm/Analysis/MemoryBuiltins.h"
81
81
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
82
+ #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
82
83
#include "llvm/Analysis/TargetLibraryInfo.h"
83
84
#include "llvm/Analysis/ValueTracking.h"
84
85
#include "llvm/Config/llvm-config.h"
133
134
134
135
using namespace llvm;
135
136
using namespace PatternMatch;
137
+ using namespace SCEVPatternMatch;
136
138
137
139
#define DEBUG_TYPE "scalar-evolution"
138
140
@@ -443,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
443
445
llvm_unreachable("Unknown SCEV kind!");
444
446
}
445
447
446
- bool SCEV::isZero() const {
447
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
448
- return SC->getValue()->isZero();
449
- return false;
450
- }
448
+ bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
451
449
452
- bool SCEV::isOne() const {
453
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
454
- return SC->getValue()->isOne();
455
- return false;
456
- }
450
+ bool SCEV::isOne() const { return match(this, m_scev_One()); }
457
451
458
- bool SCEV::isAllOnesValue() const {
459
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
460
- return SC->getValue()->isMinusOne();
461
- return false;
462
- }
452
+ bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
463
453
464
454
bool SCEV::isNonConstantNegative() const {
465
455
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -3423,9 +3413,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423
3413
return S;
3424
3414
3425
3415
// 0 udiv Y == 0
3426
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3427
- if (LHSC->getValue()->isZero())
3428
- return LHS;
3416
+ if (match(LHS, m_scev_Zero()))
3417
+ return LHS;
3429
3418
3430
3419
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3431
3420
if (RHSC->getValue()->isOne())
@@ -10593,7 +10582,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10593
10582
// Get the initial value for the loop.
10594
10583
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10595
10584
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10596
- const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10597
10585
10598
10586
if (!isLoopInvariant(Step, L))
10599
10587
return getCouldNotCompute();
@@ -10615,8 +10603,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10615
10603
// Handle unitary steps, which cannot wraparound.
10616
10604
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
10617
10605
// N = Distance (as unsigned)
10618
- if (StepC &&
10619
- (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne( ))) {
10606
+
10607
+ if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes() ))) {
10620
10608
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10621
10609
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10622
10610
@@ -10668,6 +10656,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10668
10656
}
10669
10657
10670
10658
// Solve the general equation.
10659
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10671
10660
if (!StepC || StepC->getValue()->isZero())
10672
10661
return getCouldNotCompute();
10673
10662
const SCEV *E = SolveLinEquationWithOverflow(
@@ -15510,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
15510
15499
15511
15500
// If we have LHS == 0, check if LHS is computing a property of some unknown
15512
15501
// SCEV %v which we can rewrite %v to express explicitly.
15513
- const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15514
- if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15515
- RHSC->getValue()->isNullValue()) {
15502
+ if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15516
15503
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15517
15504
// explicitly express that.
15518
15505
const SCEV *URemLHS = nullptr;
@@ -15693,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
15693
15680
To = RHS;
15694
15681
break;
15695
15682
case CmpInst::ICMP_NE:
15696
- if (isa<SCEVConstant>(RHS) &&
15697
- cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15683
+ if (match(RHS, m_scev_Zero())) {
15698
15684
const SCEV *OneAlignedUp =
15699
15685
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15700
15686
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
0 commit comments