Skip to content

Commit 70f3863

Browse files
committed
[DAG][PatternMatch] Add support for matchers with flags; NFC
Add support for matching with `SDNodeFlags` i.e `add` with `nuw`. This patch adds helpers for `or disjoint` or `zext nneg` with the same names as we have in IR/PatternMatch api. Closes #103060
1 parent 0224d83 commit 70f3863

File tree

3 files changed

+90
-11
lines changed

3 files changed

+90
-11
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -514,19 +514,28 @@ struct BinaryOpc_match {
514514
unsigned Opcode;
515515
LHS_P LHS;
516516
RHS_P RHS;
517-
518-
BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
519-
: Opcode(Opc), LHS(L), RHS(R) {}
517+
std::optional<SDNodeFlags> Flags;
518+
BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
519+
std::optional<SDNodeFlags> Flgs = std::nullopt)
520+
: Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
520521

521522
template <typename MatchContext>
522523
bool match(const MatchContext &Ctx, SDValue N) {
523524
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
524525
EffectiveOperands<ExcludeChain> EO(N, Ctx);
525526
assert(EO.Size == 2);
526-
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
527-
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
528-
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
529-
RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
527+
if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
528+
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
529+
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
530+
RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
531+
return false;
532+
533+
if (!Flags.has_value())
534+
return true;
535+
536+
SDNodeFlags TmpFlags = *Flags;
537+
TmpFlags.intersectWith(N->getFlags());
538+
return TmpFlags == *Flags;
530539
}
531540

532541
return false;
@@ -581,6 +590,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
581590
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
582591
}
583592

593+
template <typename LHS, typename RHS>
594+
inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
595+
const RHS &R) {
596+
SDNodeFlags Flags;
597+
Flags.setDisjoint(true);
598+
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
599+
}
600+
601+
template <typename LHS, typename RHS>
602+
inline auto m_AddLike(const LHS &L, const RHS &R) {
603+
return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
604+
}
605+
584606
template <typename LHS, typename RHS>
585607
inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
586608
return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
@@ -667,15 +689,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
667689
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
668690
unsigned Opcode;
669691
Opnd_P Opnd;
670-
671-
UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
692+
std::optional<SDNodeFlags> Flags;
693+
UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
694+
std::optional<SDNodeFlags> Flgs = std::nullopt)
695+
: Opcode(Opc), Opnd(Op), Flags(Flgs) {}
672696

673697
template <typename MatchContext>
674698
bool match(const MatchContext &Ctx, SDValue N) {
675699
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
676700
EffectiveOperands<ExcludeChain> EO(N, Ctx);
677701
assert(EO.Size == 1);
678-
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
702+
if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
703+
return false;
704+
if (!Flags.has_value())
705+
return true;
706+
707+
SDNodeFlags TmpFlags = *Flags;
708+
TmpFlags.intersectWith(N->getFlags());
709+
return TmpFlags == *Flags;
679710
}
680711

681712
return false;
@@ -701,6 +732,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
701732
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
702733
}
703734

735+
template <typename Opnd>
736+
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
737+
SDNodeFlags Flags;
738+
Flags.setNonNeg(true);
739+
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
740+
}
741+
704742
template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
705743
return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
706744
}
@@ -725,6 +763,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
725763
return m_AnyOf(m_SExt(Op), Op);
726764
}
727765

766+
template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
767+
return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
768+
}
769+
728770
/// Match a aext or identity
729771
/// Allows to peek through optional extensions
730772
template <typename Opnd>

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,20 @@ struct SDNodeFlags {
452452
bool hasNoFPExcept() const { return NoFPExcept; }
453453
bool hasUnpredictable() const { return Unpredictable; }
454454

455+
bool operator==(const SDNodeFlags &Other) const {
456+
return NoUnsignedWrap == Other.NoUnsignedWrap &&
457+
NoSignedWrap == Other.NoSignedWrap && Exact == Other.Exact &&
458+
Disjoint == Other.Disjoint && NonNeg == Other.NonNeg &&
459+
NoNaNs == Other.NoNaNs && NoInfs == Other.NoInfs &&
460+
NoSignedZeros == Other.NoSignedZeros &&
461+
AllowReciprocal == Other.AllowReciprocal &&
462+
AllowContract == Other.AllowContract &&
463+
ApproximateFuncs == Other.ApproximateFuncs &&
464+
AllowReassociation == Other.AllowReassociation &&
465+
NoFPExcept == Other.NoFPExcept &&
466+
Unpredictable == Other.Unpredictable;
467+
}
468+
455469
/// Clear any flags in this flag set that aren't also set in Flags. All
456470
/// flags will be cleared if Flags are undefined.
457471
void intersectWith(const SDNodeFlags Flags) {

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,17 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
185185
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
186186
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
187187
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
188+
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
188189

189190
SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
190191
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
191192
SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
192193
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
193194
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
194195
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
196+
SDNodeFlags DisFlags;
197+
DisFlags.setDisjoint(true);
198+
SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
195199
SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
196200
SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
197201
SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
@@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
205209
EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
206210
EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
207211
EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
212+
EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
208213
EXPECT_TRUE(sd_match(
209214
Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
210215
EXPECT_TRUE(
@@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
217222
EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
218223
EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
219224
EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
225+
EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
226+
227+
EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
228+
EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
229+
EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
230+
EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
220231

221232
EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
222233
EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
@@ -242,9 +253,14 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
242253

243254
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
244255
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
245-
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT);
256+
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT);
257+
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
246258

247259
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
260+
SDNodeFlags NNegFlags;
261+
NNegFlags.setNonNeg(true);
262+
SDValue ZExtNNeg =
263+
DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op3, NNegFlags);
248264
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
249265
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
250266

@@ -260,6 +276,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
260276
using namespace SDPatternMatch;
261277
EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
262278
EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
279+
EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
280+
ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg());
281+
EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
282+
EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
283+
EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
284+
EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
285+
EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
263286
EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
264287

265288
EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));

0 commit comments

Comments
 (0)