Skip to content

Commit 217e0f3

Browse files
authored
[SCEV] Add initial pattern matching for SCEV constants. (NFC) (#119389)
Add initial pattern matching for SCEV constants. Follow-up patches will add additional matchers for various SCEV expressions. This patch only converts a few instances to use the new matchers to make sure everything builds as expected for now. PR: #119389
1 parent ff939b0 commit 217e0f3

File tree

2 files changed

+70
-26
lines changed

2 files changed

+70
-26
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===----------------------------------------------------------------------===//
2+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// See https://llvm.org/LICENSE.txt for license information.
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
//
6+
//===----------------------------------------------------------------------===//
7+
//
8+
// This file provides a simple and efficient mechanism for performing general
9+
// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
14+
#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
15+
16+
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
17+
18+
namespace llvm {
19+
namespace SCEVPatternMatch {
20+
21+
template <typename Val, typename Pattern>
22+
bool match(const SCEV *S, const Pattern &P) {
23+
return P.match(S);
24+
}
25+
26+
template <typename Predicate> struct cst_pred_ty : public Predicate {
27+
bool match(const SCEV *S) {
28+
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
29+
"no vector types expected from SCEVs");
30+
auto *C = dyn_cast<SCEVConstant>(S);
31+
return C && this->isValue(C->getAPInt());
32+
}
33+
};
34+
35+
struct is_zero {
36+
bool isValue(const APInt &C) { return C.isZero(); }
37+
};
38+
/// Match an integer 0.
39+
inline cst_pred_ty<is_zero> m_scev_Zero() { return cst_pred_ty<is_zero>(); }
40+
41+
struct is_one {
42+
bool isValue(const APInt &C) { return C.isOne(); }
43+
};
44+
/// Match an integer 1.
45+
inline cst_pred_ty<is_one> m_scev_One() { return cst_pred_ty<is_one>(); }
46+
47+
struct is_all_ones {
48+
bool isValue(const APInt &C) { return C.isAllOnes(); }
49+
};
50+
/// Match an integer with all bits set.
51+
inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
52+
return cst_pred_ty<is_all_ones>();
53+
}
54+
55+
} // namespace SCEVPatternMatch
56+
} // namespace llvm
57+
58+
#endif

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "llvm/Analysis/LoopInfo.h"
8080
#include "llvm/Analysis/MemoryBuiltins.h"
8181
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
82+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
8283
#include "llvm/Analysis/TargetLibraryInfo.h"
8384
#include "llvm/Analysis/ValueTracking.h"
8485
#include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
133134

134135
using namespace llvm;
135136
using namespace PatternMatch;
137+
using namespace SCEVPatternMatch;
136138

137139
#define DEBUG_TYPE "scalar-evolution"
138140

@@ -443,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
443445
llvm_unreachable("Unknown SCEV kind!");
444446
}
445447

446-
bool SCEV::isZero() const {
447-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
448-
return SC->getValue()->isZero();
449-
return false;
450-
}
448+
bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
451449

452-
bool SCEV::isOne() const {
453-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
454-
return SC->getValue()->isOne();
455-
return false;
456-
}
450+
bool SCEV::isOne() const { return match(this, m_scev_One()); }
457451

458-
bool SCEV::isAllOnesValue() const {
459-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
460-
return SC->getValue()->isMinusOne();
461-
return false;
462-
}
452+
bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
463453

464454
bool SCEV::isNonConstantNegative() const {
465455
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -3423,9 +3413,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
34233413
return S;
34243414

34253415
// 0 udiv Y == 0
3426-
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3427-
if (LHSC->getValue()->isZero())
3428-
return LHS;
3416+
if (match(LHS, m_scev_Zero()))
3417+
return LHS;
34293418

34303419
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
34313420
if (RHSC->getValue()->isOne())
@@ -10593,7 +10582,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1059310582
// Get the initial value for the loop.
1059410583
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
1059510584
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10596-
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1059710585

1059810586
if (!isLoopInvariant(Step, L))
1059910587
return getCouldNotCompute();
@@ -10615,8 +10603,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1061510603
// Handle unitary steps, which cannot wraparound.
1061610604
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
1061710605
// N = Distance (as unsigned)
10618-
if (StepC &&
10619-
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10606+
10607+
if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
1062010608
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
1062110609
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
1062210610

@@ -10668,6 +10656,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1066810656
}
1066910657

1067010658
// Solve the general equation.
10659+
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1067110660
if (!StepC || StepC->getValue()->isZero())
1067210661
return getCouldNotCompute();
1067310662
const SCEV *E = SolveLinEquationWithOverflow(
@@ -15510,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551015499

1551115500
// If we have LHS == 0, check if LHS is computing a property of some unknown
1551215501
// SCEV %v which we can rewrite %v to express explicitly.
15513-
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15514-
if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15515-
RHSC->getValue()->isNullValue()) {
15502+
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1551615503
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1551715504
// explicitly express that.
1551815505
const SCEV *URemLHS = nullptr;
@@ -15693,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569315680
To = RHS;
1569415681
break;
1569515682
case CmpInst::ICMP_NE:
15696-
if (isa<SCEVConstant>(RHS) &&
15697-
cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15683+
if (match(RHS, m_scev_Zero())) {
1569815684
const SCEV *OneAlignedUp =
1569915685
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
1570015686
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);

0 commit comments

Comments
 (0)