Skip to content

Commit 98fe785

Browse files
committed
[IR][PatternMatch] Make m_Checked{Int,Fp} accept Constant * output instead of APInt *
The `APInt *` version is pretty useless as any case one needs an `APInt *` out, they could just replace whatever they have the `m_Checked...` lambda with direct checks on the `APInt`. Leaving other helpers such as `m_Negative`, `m_Power2`, etc... unchanged as the `APInt` out version is used mostly for convenience and rarely change functionality when converted output a `Constant *`.
1 parent 45fed80 commit 98fe785

File tree

2 files changed

+74
-47
lines changed

2 files changed

+74
-47
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
354354
/// is true.
355355
template <typename Predicate, typename ConstantVal, bool AllowPoison>
356356
struct cstval_pred_ty : public Predicate {
357-
template <typename ITy> bool match(ITy *V) {
357+
const Constant **Res = nullptr;
358+
template <typename ITy> bool match_impl(ITy *V) {
358359
if (const auto *CV = dyn_cast<ConstantVal>(V))
359360
return this->isValue(CV->getValue());
360361
if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
@@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
387388
}
388389
return false;
389390
}
391+
392+
template <typename ITy> bool match(ITy *V) {
393+
if (this->match_impl(V)) {
394+
if (Res)
395+
*Res = cast<Constant>(V);
396+
return true;
397+
}
398+
return false;
399+
}
390400
};
391401

392402
/// specialization of cstval_pred_ty for ConstantInt
@@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
469479
/// For vectors, poison elements are assumed to match.
470480
inline cst_pred_ty<custom_checkfn<APInt>>
471481
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
472-
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
482+
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
473483
}
474484

475-
inline api_pred_ty<custom_checkfn<APInt>>
476-
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
477-
api_pred_ty<custom_checkfn<APInt>> P(V);
478-
P.CheckFn = CheckFn;
479-
return P;
485+
inline cst_pred_ty<custom_checkfn<APInt>>
486+
m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
487+
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
480488
}
481489

482490
/// Match a float or vector where CheckFn(ele) for each element is true.
483491
/// For vectors, poison elements are assumed to match.
484492
inline cstfp_pred_ty<custom_checkfn<APFloat>>
485493
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
486-
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
494+
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
487495
}
488496

489-
inline apf_pred_ty<custom_checkfn<APFloat>>
490-
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
491-
apf_pred_ty<custom_checkfn<APFloat>> P(V);
492-
P.CheckFn = CheckFn;
493-
return P;
497+
inline cstfp_pred_ty<custom_checkfn<APFloat>>
498+
m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
499+
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
494500
}
495501

496502
struct is_any_apint {

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ TEST_F(PatternMatchTest, BitCast) {
614614
TEST_F(PatternMatchTest, CheckedInt) {
615615
Type *I8Ty = IRB.getInt8Ty();
616616
const APInt *Res = nullptr;
617-
617+
const Constant * CRes = nullptr;
618618
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
619619
auto CheckTrue = [](const APInt &) { return true; };
620620
auto CheckFalse = [](const APInt &) { return false; };
@@ -625,39 +625,49 @@ TEST_F(PatternMatchTest, CheckedInt) {
625625
APInt APVal(8, Val);
626626
Constant *C = ConstantInt::get(I8Ty, Val);
627627

628+
CRes = nullptr;
628629
Res = nullptr;
629630
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
630-
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
631+
EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
632+
EXPECT_NE(CRes, nullptr);
633+
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
631634
EXPECT_EQ(*Res, APVal);
632635

636+
CRes = nullptr;
633637
Res = nullptr;
634638
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
635-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
639+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
640+
EXPECT_EQ(CRes, nullptr);
636641

642+
CRes = nullptr;
637643
Res = nullptr;
638644
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
639-
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
645+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
640646
if (CheckUgt1(APVal)) {
641-
EXPECT_NE(Res, nullptr);
647+
EXPECT_NE(CRes, nullptr);
648+
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
642649
EXPECT_EQ(*Res, APVal);
643650
}
644651

652+
CRes = nullptr;
645653
Res = nullptr;
646654
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
647-
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
655+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
648656
if (CheckNonZero(APVal)) {
649-
EXPECT_NE(Res, nullptr);
657+
EXPECT_NE(CRes, nullptr);
658+
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
650659
EXPECT_EQ(*Res, APVal);
651660
}
652661

662+
CRes = nullptr;
653663
Res = nullptr;
654664
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
655-
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
665+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
656666
if (CheckPow2(APVal)) {
657-
EXPECT_NE(Res, nullptr);
667+
EXPECT_NE(CRes, nullptr);
668+
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
658669
EXPECT_EQ(*Res, APVal);
659670
}
660-
661671
};
662672

663673
DoScalarCheck(0);
@@ -666,20 +676,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
666676
DoScalarCheck(3);
667677

668678
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
669-
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
670-
EXPECT_EQ(Res, nullptr);
679+
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
680+
EXPECT_EQ(CRes, nullptr);
671681

672682
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
673-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
674-
EXPECT_EQ(Res, nullptr);
683+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
684+
EXPECT_EQ(CRes, nullptr);
675685

676686
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
677-
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
678-
EXPECT_EQ(Res, nullptr);
687+
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
688+
EXPECT_EQ(CRes, nullptr);
679689

680690
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
681-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
682-
EXPECT_EQ(Res, nullptr);
691+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
692+
EXPECT_EQ(CRes, nullptr);
683693

684694
auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
685695
function_ref<bool(const APInt &)> CheckFn,
@@ -711,13 +721,16 @@ TEST_F(PatternMatchTest, CheckedInt) {
711721
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
712722
m_CheckedInt(CheckFn).match(C));
713723

724+
CRes = nullptr;
714725
Res = nullptr;
715726
bool Expec =
716-
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
717-
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
727+
!(HasUndef && !UndefAsPoison) && Okay.value_or(false);
728+
EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
718729
if (Expec) {
719-
EXPECT_NE(Res, nullptr);
720-
EXPECT_EQ(*Res, *First);
730+
EXPECT_NE(CRes, nullptr);
731+
EXPECT_EQ(match(CRes, m_APIntAllowPoison(Res)), AllSame);
732+
if (AllSame)
733+
EXPECT_EQ(*Res, *First);
721734
}
722735
};
723736
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
@@ -1559,24 +1572,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
15591572
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
15601573

15611574
const APFloat *C;
1575+
const Constant *CC;
15621576
// Regardless of whether poison is allowed,
15631577
// a fully undef/poison constant does not match.
15641578
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
15651579
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
15661580
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
1567-
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
1581+
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
15681582
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
15691583
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
15701584
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
1571-
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
1585+
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
15721586
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
15731587
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
15741588
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
1575-
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
1589+
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
15761590
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
15771591
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
15781592
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
1579-
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
1593+
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));
15801594

15811595
// We can always match simple constants and simple splats.
15821596
C = nullptr;
@@ -1597,12 +1611,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
15971611
C = nullptr;
15981612
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
15991613
EXPECT_TRUE(C->isZero());
1600-
C = nullptr;
1601-
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
1602-
EXPECT_TRUE(C->isZero());
1603-
C = nullptr;
1604-
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
1605-
EXPECT_TRUE(C->isZero());
1614+
1615+
CC = nullptr;
1616+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
1617+
EXPECT_TRUE(CC->isNullValue());
1618+
CC = nullptr;
1619+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
1620+
EXPECT_TRUE(CC->isNullValue());
16061621

16071622
// Splats with undef are never allowed.
16081623
// Whether splats with poison can be matched depends on the matcher.
@@ -1627,11 +1642,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
16271642
C = nullptr;
16281643
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
16291644
EXPECT_TRUE(C->isZero());
1645+
CC = nullptr;
16301646
C = nullptr;
1631-
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
1647+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
1648+
EXPECT_NE(CC, nullptr);
1649+
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
16321650
EXPECT_TRUE(C->isZero());
1651+
CC = nullptr;
16331652
C = nullptr;
1634-
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
1653+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
1654+
EXPECT_NE(CC, nullptr);
1655+
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
16351656
EXPECT_TRUE(C->isZero());
16361657
}
16371658

0 commit comments

Comments
 (0)