Skip to content

Commit dc540d1

Browse files
committed
[SCEV] Add initial pattern matching for SCEV constants. (NFC)
Add initial pattern matching for SCEV constants. Follow-up patches will add additional matchers for various SCEV expressions.
1 parent ecbf64d commit dc540d1

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides a simple and efficient mechanism for performing general
10+
// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
15+
#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
16+
17+
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
18+
19+
namespace llvm {
20+
namespace SCEVPatternMatch {
21+
22+
template <typename Val, typename Pattern>
23+
bool match(const SCEV *S, const Pattern &P) {
24+
return P.match(S);
25+
}
26+
27+
/// Match a specified integer value. \p BitWidth optionally specifies the
28+
/// bitwidth the matched constant must have. If it is 0, the matched constant
29+
/// can have any bitwidth.
30+
template <unsigned BitWidth = 0> struct specific_intval {
31+
APInt Val;
32+
33+
specific_intval(APInt V) : Val(std::move(V)) {}
34+
35+
bool match(const SCEV *S) const {
36+
const auto *C = dyn_cast<SCEVConstant>(S);
37+
if (!C)
38+
return false;
39+
40+
if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
41+
return false;
42+
return APInt::isSameValue(C->getAPInt(), Val);
43+
}
44+
};
45+
46+
inline specific_intval<0> m_scev_Zero() {
47+
return specific_intval<0>(APInt(64, 0));
48+
}
49+
inline specific_intval<0> m_scev_One() {
50+
return specific_intval<0>(APInt(64, 1));
51+
}
52+
inline specific_intval<0> m_scev_MinusOne() {
53+
return specific_intval<0>(APInt(64, -1));
54+
}
55+
56+
} // namespace SCEVPatternMatch
57+
} // namespace llvm
58+
59+
#endif

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "llvm/Analysis/LoopInfo.h"
8080
#include "llvm/Analysis/MemoryBuiltins.h"
8181
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
82+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
8283
#include "llvm/Analysis/TargetLibraryInfo.h"
8384
#include "llvm/Analysis/ValueTracking.h"
8485
#include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
133134

134135
using namespace llvm;
135136
using namespace PatternMatch;
137+
using namespace SCEVPatternMatch;
136138

137139
#define DEBUG_TYPE "scalar-evolution"
138140

@@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
34233425
return S;
34243426

34253427
// 0 udiv Y == 0
3426-
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3427-
if (LHSC->getValue()->isZero())
3428-
return LHS;
3428+
if (match(LHS, m_scev_Zero()))
3429+
return LHS;
34293430

34303431
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
34313432
if (RHSC->getValue()->isOne())
@@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1059310594
// Get the initial value for the loop.
1059410595
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
1059510596
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10596-
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1059710597

1059810598
if (!isLoopInvariant(Step, L))
1059910599
return getCouldNotCompute();
@@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1061510615
// Handle unitary steps, which cannot wraparound.
1061610616
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
1061710617
// N = Distance (as unsigned)
10618-
if (StepC &&
10619-
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10618+
10619+
if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
1062010620
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
1062110621
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
1062210622

@@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1066810668
}
1066910669

1067010670
// Solve the general equation.
10671+
const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
1067110672
if (!StepC || StepC->getValue()->isZero())
1067210673
return getCouldNotCompute();
1067310674
const SCEV *E = SolveLinEquationWithOverflow(

0 commit comments

Comments
 (0)