Skip to content

Commit 881795f

Browse files
committed
!fixup use predicates
1 parent 7eed265 commit 881795f

File tree

2 files changed

+47
-34
lines changed

2 files changed

+47
-34
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
2-
//
1+
//===----------------------------------------------------------------------===//
32
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43
// See https://llvm.org/LICENSE.txt for license information.
54
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -15,7 +14,6 @@
1514
#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
1615

1716
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
18-
#include "llvm/IR/PatternMatch.h"
1917

2018
namespace llvm {
2119
namespace SCEVPatternMatch {
@@ -25,19 +23,49 @@ bool match(const SCEV *S, const Pattern &P) {
2523
return P.match(S);
2624
}
2725

28-
struct specific_intval64 : public PatternMatch::specific_intval64<false> {
29-
specific_intval64(uint64_t V) : PatternMatch::specific_intval64<false>(V) {}
30-
26+
template <typename Predicate> struct cst_pred_ty : public Predicate {
3127
bool match(const SCEV *S) {
32-
auto *Cast = dyn_cast<SCEVConstant>(S);
33-
return Cast &&
34-
PatternMatch::specific_intval64<false>::match(Cast->getValue());
28+
auto *C = dyn_cast<SCEVConstant>(S);
29+
return C && this->isValue(C->getAPInt());
30+
}
31+
};
32+
33+
struct is_zero_int {
34+
bool isValue(const APInt &C) { return C.isZero(); }
35+
};
36+
37+
/// Match an integer 0 or a vector with all elements equal to 0.
38+
/// For vectors, this includes constants with undefined elements.
39+
inline cst_pred_ty<is_zero_int> m_scev_ZeroInt() {
40+
return cst_pred_ty<is_zero_int>();
41+
}
42+
43+
struct is_zero {
44+
template <typename ITy> bool match(ITy *V) {
45+
auto *C = dyn_cast<SCEVConstant>(V);
46+
return C && (C->getValue()->isNullValue() ||
47+
cst_pred_ty<is_zero_int>().match(C));
3548
}
3649
};
50+
/// Match any null constant or a vector with all elements equal to 0.
51+
/// For vectors, this includes constants with undefined elements.
52+
inline is_zero m_scev_Zero() { return is_zero(); }
3753

38-
inline specific_intval64 m_scev_Zero() { return specific_intval64(0); }
39-
inline specific_intval64 m_scev_One() { return specific_intval64(1); }
40-
inline specific_intval64 m_scev_MinusOne() { return specific_intval64(-1); }
54+
struct is_one {
55+
bool isValue(const APInt &C) { return C.isOne(); }
56+
};
57+
/// Match an integer 1 or a vector with all elements equal to 1.
58+
/// For vectors, this includes constants with undefined elements.
59+
inline cst_pred_ty<is_one> m_scev_One() { return cst_pred_ty<is_one>(); }
60+
61+
struct is_all_ones {
62+
bool isValue(const APInt &C) { return C.isAllOnes(); }
63+
};
64+
/// Match an integer or vector with all bits set.
65+
/// For vectors, this includes constants with undefined elements.
66+
inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
67+
return cst_pred_ty<is_all_ones>();
68+
}
4169

4270
} // namespace SCEVPatternMatch
4371
} // namespace llvm

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -445,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
445445
llvm_unreachable("Unknown SCEV kind!");
446446
}
447447

448-
bool SCEV::isZero() const {
449-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450-
return SC->getValue()->isZero();
451-
return false;
452-
}
448+
bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
453449

454-
bool SCEV::isOne() const {
455-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456-
return SC->getValue()->isOne();
457-
return false;
458-
}
450+
bool SCEV::isOne() const { return match(this, m_scev_One()); }
459451

460-
bool SCEV::isAllOnesValue() const {
461-
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
462-
return SC->getValue()->isMinusOne();
463-
return false;
464-
}
452+
bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
465453

466454
bool SCEV::isNonConstantNegative() const {
467455
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -3425,7 +3413,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
34253413
return S;
34263414

34273415
// 0 udiv Y == 0
3428-
if (match(LHS, m_scev_Zero()))
3416+
if (match(LHS, m_scev_ZeroInt()))
34293417
return LHS;
34303418

34313419
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
@@ -10616,7 +10604,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1061610604
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
1061710605
// N = Distance (as unsigned)
1061810606

10619-
if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
10607+
if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
1062010608
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
1062110609
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
1062210610

@@ -15511,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551115499

1551215500
// If we have LHS == 0, check if LHS is computing a property of some unknown
1551315501
// SCEV %v which we can rewrite %v to express explicitly.
15514-
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15515-
if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15516-
RHSC->getValue()->isNullValue()) {
15502+
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1551715503
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1551815504
// explicitly express that.
1551915505
const SCEV *URemLHS = nullptr;
@@ -15694,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569415680
To = RHS;
1569515681
break;
1569615682
case CmpInst::ICMP_NE:
15697-
if (isa<SCEVConstant>(RHS) &&
15698-
cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15683+
if (match(RHS, m_scev_Zero())) {
1569915684
const SCEV *OneAlignedUp =
1570015685
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
1570115686
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);

0 commit comments

Comments
 (0)