Skip to content

Commit 11cb3c3

Browse files
committed
[IR][PatternMatch] Make m_Checked{Int,Fp} accept Constant * output instead of APInt *
The `APInt *` version is pretty useless as any case one needs an `APInt *` out, they could just replace whatever they have the `m_Checked...` lambda with direct checks on the `APInt`. Leaving other helpers such as `m_Negative`, `m_Power2`, etc... unchanged as the `APInt` out version is used mostly for convenience and rarely change functionality when converted output a `Constant *`. Closes #91377
1 parent 38b2755 commit 11cb3c3

File tree

2 files changed

+72
-64
lines changed

2 files changed

+72
-64
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
354354
/// is true.
355355
template <typename Predicate, typename ConstantVal, bool AllowPoison>
356356
struct cstval_pred_ty : public Predicate {
357-
template <typename ITy> bool match(ITy *V) {
357+
const Constant **Res = nullptr;
358+
template <typename ITy> bool match_impl(ITy *V) {
358359
if (const auto *CV = dyn_cast<ConstantVal>(V))
359360
return this->isValue(CV->getValue());
360361
if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
@@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
387388
}
388389
return false;
389390
}
391+
392+
template <typename ITy> bool match(ITy *V) {
393+
if (this->match_impl(V)) {
394+
if (Res)
395+
*Res = cast<Constant>(V);
396+
return true;
397+
}
398+
return false;
399+
}
390400
};
391401

392402
/// specialization of cstval_pred_ty for ConstantInt
@@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
469479
/// For vectors, poison elements are assumed to match.
470480
inline cst_pred_ty<custom_checkfn<APInt>>
471481
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
472-
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
482+
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
473483
}
474484

475-
inline api_pred_ty<custom_checkfn<APInt>>
476-
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
477-
api_pred_ty<custom_checkfn<APInt>> P(V);
478-
P.CheckFn = CheckFn;
479-
return P;
485+
inline cst_pred_ty<custom_checkfn<APInt>>
486+
m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
487+
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
480488
}
481489

482490
/// Match a float or vector where CheckFn(ele) for each element is true.
483491
/// For vectors, poison elements are assumed to match.
484492
inline cstfp_pred_ty<custom_checkfn<APFloat>>
485493
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
486-
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
494+
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
487495
}
488496

489-
inline apf_pred_ty<custom_checkfn<APFloat>>
490-
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
491-
apf_pred_ty<custom_checkfn<APFloat>> P(V);
492-
P.CheckFn = CheckFn;
493-
return P;
497+
inline cstfp_pred_ty<custom_checkfn<APFloat>>
498+
m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
499+
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
494500
}
495501

496502
struct is_any_apint {

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,7 @@ TEST_F(PatternMatchTest, BitCast) {
613613

614614
TEST_F(PatternMatchTest, CheckedInt) {
615615
Type *I8Ty = IRB.getInt8Ty();
616-
const APInt *Res = nullptr;
617-
616+
const Constant * CRes = nullptr;
618617
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
619618
auto CheckTrue = [](const APInt &) { return true; };
620619
auto CheckFalse = [](const APInt &) { return false; };
@@ -625,38 +624,33 @@ TEST_F(PatternMatchTest, CheckedInt) {
625624
APInt APVal(8, Val);
626625
Constant *C = ConstantInt::get(I8Ty, Val);
627626

628-
Res = nullptr;
627+
CRes = nullptr;
629628
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
630-
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
631-
EXPECT_EQ(*Res, APVal);
629+
EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
630+
EXPECT_EQ(CRes, C);
632631

633-
Res = nullptr;
632+
CRes = nullptr;
634633
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
635-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
634+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
635+
EXPECT_EQ(CRes, nullptr);
636636

637-
Res = nullptr;
637+
CRes = nullptr;
638638
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
639-
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
640-
if (CheckUgt1(APVal)) {
641-
EXPECT_NE(Res, nullptr);
642-
EXPECT_EQ(*Res, APVal);
643-
}
639+
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
640+
if (CheckUgt1(APVal))
641+
EXPECT_EQ(CRes, C);
644642

645-
Res = nullptr;
643+
CRes = nullptr;
646644
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
647-
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
648-
if (CheckNonZero(APVal)) {
649-
EXPECT_NE(Res, nullptr);
650-
EXPECT_EQ(*Res, APVal);
651-
}
645+
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
646+
if (CheckNonZero(APVal))
647+
EXPECT_EQ(CRes, C);
652648

653-
Res = nullptr;
649+
CRes = nullptr;
654650
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
655-
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
656-
if (CheckPow2(APVal)) {
657-
EXPECT_NE(Res, nullptr);
658-
EXPECT_EQ(*Res, APVal);
659-
}
651+
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
652+
if (CheckPow2(APVal))
653+
EXPECT_EQ(CRes, C);
660654

661655
};
662656

@@ -666,20 +660,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
666660
DoScalarCheck(3);
667661

668662
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
669-
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
670-
EXPECT_EQ(Res, nullptr);
663+
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
664+
EXPECT_EQ(CRes, nullptr);
671665

672666
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
673-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
674-
EXPECT_EQ(Res, nullptr);
667+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
668+
EXPECT_EQ(CRes, nullptr);
675669

676670
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
677-
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
678-
EXPECT_EQ(Res, nullptr);
671+
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
672+
EXPECT_EQ(CRes, nullptr);
679673

680674
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
681-
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
682-
EXPECT_EQ(Res, nullptr);
675+
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
676+
EXPECT_EQ(CRes, nullptr);
683677

684678
auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
685679
function_ref<bool(const APInt &)> CheckFn,
@@ -711,13 +705,13 @@ TEST_F(PatternMatchTest, CheckedInt) {
711705
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
712706
m_CheckedInt(CheckFn).match(C));
713707

714-
Res = nullptr;
715-
bool Expec =
716-
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
717-
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
708+
CRes = nullptr;
709+
bool Expec = !(HasUndef && !UndefAsPoison) && Okay.value_or(false);
710+
EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
718711
if (Expec) {
719-
EXPECT_NE(Res, nullptr);
720-
EXPECT_EQ(*Res, *First);
712+
EXPECT_NE(CRes, nullptr);
713+
if (AllSame)
714+
EXPECT_EQ(CRes, C);
721715
}
722716
};
723717
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
@@ -1559,24 +1553,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
15591553
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
15601554

15611555
const APFloat *C;
1556+
const Constant *CC;
15621557
// Regardless of whether poison is allowed,
15631558
// a fully undef/poison constant does not match.
15641559
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
15651560
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
15661561
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
1567-
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
1562+
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
15681563
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
15691564
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
15701565
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
1571-
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
1566+
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
15721567
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
15731568
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
15741569
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
1575-
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
1570+
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
15761571
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
15771572
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
15781573
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
1579-
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
1574+
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));
15801575

15811576
// We can always match simple constants and simple splats.
15821577
C = nullptr;
@@ -1597,12 +1592,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
15971592
C = nullptr;
15981593
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
15991594
EXPECT_TRUE(C->isZero());
1600-
C = nullptr;
1601-
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
1602-
EXPECT_TRUE(C->isZero());
1603-
C = nullptr;
1604-
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
1605-
EXPECT_TRUE(C->isZero());
1595+
1596+
CC = nullptr;
1597+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
1598+
EXPECT_TRUE(CC->isNullValue());
1599+
CC = nullptr;
1600+
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
1601+
EXPECT_TRUE(CC->isNullValue());
16061602

16071603
// Splats with undef are never allowed.
16081604
// Whether splats with poison can be matched depends on the matcher.
@@ -1627,11 +1623,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
16271623
C = nullptr;
16281624
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
16291625
EXPECT_TRUE(C->isZero());
1626+
CC = nullptr;
16301627
C = nullptr;
1631-
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
1628+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
1629+
EXPECT_NE(CC, nullptr);
1630+
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
16321631
EXPECT_TRUE(C->isZero());
1632+
CC = nullptr;
16331633
C = nullptr;
1634-
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
1634+
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
1635+
EXPECT_NE(CC, nullptr);
1636+
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
16351637
EXPECT_TRUE(C->isZero());
16361638
}
16371639

0 commit comments

Comments
 (0)