Skip to content

Commit 6b00ae6

Browse files
Esan5mshockwave
andauthored
[DAG] SDPatternMatch - add matchers for reassociatable binops (#119985)
fixes #118847 implements matchers for reassociatable opcodes as well as helpers for commonly used reassociatable binary matchers. --------- Co-authored-by: Min-Yih Hsu <[email protected]>
1 parent 052a4b5 commit 6b00ae6

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
#define LLVM_CODEGEN_SDPATTERNMATCH_H
1515

1616
#include "llvm/ADT/APInt.h"
17+
#include "llvm/ADT/ArrayRef.h"
1718
#include "llvm/ADT/STLExtras.h"
19+
#include "llvm/ADT/SmallBitVector.h"
1820
#include "llvm/CodeGen/SelectionDAG.h"
1921
#include "llvm/CodeGen/SelectionDAGNodes.h"
2022
#include "llvm/CodeGen/TargetLowering.h"
@@ -1134,6 +1136,87 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
11341136
return m_Xor(V, m_AllOnes());
11351137
}
11361138

1139+
template <typename... PatternTs> struct ReassociatableOpc_match {
1140+
unsigned Opcode;
1141+
std::tuple<PatternTs...> Patterns;
1142+
1143+
ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
1144+
: Opcode(Opcode), Patterns(Patterns...) {}
1145+
1146+
template <typename MatchContext>
1147+
bool match(const MatchContext &Ctx, SDValue N) {
1148+
SmallVector<SDValue> Leaves;
1149+
collectLeaves(N, Leaves);
1150+
if (Leaves.size() != std::tuple_size_v<std::tuple<PatternTs...>>)
1151+
return false;
1152+
1153+
// Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
1154+
// std::get<J>(Patterns)) == true
1155+
std::array<SmallBitVector, std::tuple_size_v<std::tuple<PatternTs...>>>
1156+
Matches;
1157+
for (size_t I = 0, N = Leaves.size(); I < N; I++) {
1158+
SmallVector<bool> MatchResults;
1159+
std::apply(
1160+
[&](auto &...P) {
1161+
(Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
1162+
},
1163+
Patterns);
1164+
}
1165+
1166+
SmallBitVector Used(std::tuple_size_v<std::tuple<PatternTs...>>);
1167+
return reassociatableMatchHelper(Matches, Used);
1168+
}
1169+
1170+
void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
1171+
if (V->getOpcode() == Opcode) {
1172+
for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
1173+
collectLeaves(V->getOperand(I), Leaves);
1174+
} else {
1175+
Leaves.emplace_back(V);
1176+
}
1177+
}
1178+
1179+
[[nodiscard]] inline bool
1180+
reassociatableMatchHelper(const ArrayRef<SmallBitVector> Matches,
1181+
SmallBitVector &Used, size_t Curr = 0) {
1182+
if (Curr == Matches.size())
1183+
return true;
1184+
for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
1185+
if (!Matches[Curr][Match] || Used[Match])
1186+
continue;
1187+
Used[Match] = true;
1188+
if (reassociatableMatchHelper(Matches, Used, Curr + 1))
1189+
return true;
1190+
Used[Match] = false;
1191+
}
1192+
return false;
1193+
}
1194+
};
1195+
1196+
template <typename... PatternTs>
1197+
inline ReassociatableOpc_match<PatternTs...>
1198+
m_ReassociatableAdd(const PatternTs &...Patterns) {
1199+
return ReassociatableOpc_match<PatternTs...>(ISD::ADD, Patterns...);
1200+
}
1201+
1202+
template <typename... PatternTs>
1203+
inline ReassociatableOpc_match<PatternTs...>
1204+
m_ReassociatableOr(const PatternTs &...Patterns) {
1205+
return ReassociatableOpc_match<PatternTs...>(ISD::OR, Patterns...);
1206+
}
1207+
1208+
template <typename... PatternTs>
1209+
inline ReassociatableOpc_match<PatternTs...>
1210+
m_ReassociatableAnd(const PatternTs &...Patterns) {
1211+
return ReassociatableOpc_match<PatternTs...>(ISD::AND, Patterns...);
1212+
}
1213+
1214+
template <typename... PatternTs>
1215+
inline ReassociatableOpc_match<PatternTs...>
1216+
m_ReassociatableMul(const PatternTs &...Patterns) {
1217+
return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
1218+
}
1219+
11371220
} // namespace SDPatternMatch
11381221
} // namespace llvm
11391222
#endif

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,3 +651,128 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
651651
EXPECT_TRUE(sd_match(Add, DAG.get(),
652652
m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
653653
}
654+
655+
TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
656+
using namespace SDPatternMatch;
657+
658+
SDLoc DL;
659+
auto Int32VT = EVT::getIntegerVT(Context, 32);
660+
661+
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
662+
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
663+
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
664+
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
665+
666+
// (Op0 + Op1) + (Op2 + Op3)
667+
SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
668+
SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
669+
SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
670+
671+
EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value())));
672+
EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
673+
EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
674+
EXPECT_TRUE(sd_match(
675+
ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
676+
677+
// Op0 + (Op1 + (Op2 + Op3))
678+
SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
679+
SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
680+
EXPECT_TRUE(
681+
sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
682+
EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
683+
m_Value(), m_Value())));
684+
685+
// (Op0 - Op1) + (Op2 - Op3)
686+
SDValue SUB01 = DAG->getNode(ISD::SUB, DL, Int32VT, Op0, Op1);
687+
SDValue SUB23 = DAG->getNode(ISD::SUB, DL, Int32VT, Op2, Op3);
688+
SDValue ADDS0123 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB01, SUB23);
689+
690+
EXPECT_FALSE(sd_match(SUB01, m_ReassociatableAdd(m_Value(), m_Value())));
691+
EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
692+
m_Value(), m_Value())));
693+
694+
// SUB + SUB matches (Op0 - Op1) + (Op2 - Op3)
695+
EXPECT_TRUE(
696+
sd_match(ADDS0123, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()),
697+
m_Sub(m_Value(), m_Value()))));
698+
EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
699+
m_Value(), m_Value())));
700+
701+
// (Op0 * Op1) * (Op2 * Op3)
702+
SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
703+
SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
704+
SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
705+
706+
EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
707+
EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
708+
EXPECT_TRUE(sd_match(
709+
MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
710+
711+
// Op0 * (Op1 * (Op2 * Op3))
712+
SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
713+
SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
714+
EXPECT_TRUE(
715+
sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
716+
EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
717+
m_Value(), m_Value())));
718+
719+
// (Op0 - Op1) * (Op2 - Op3)
720+
SDValue MULS0123 = DAG->getNode(ISD::MUL, DL, Int32VT, SUB01, SUB23);
721+
EXPECT_TRUE(
722+
sd_match(MULS0123, m_ReassociatableMul(m_Sub(m_Value(), m_Value()),
723+
m_Sub(m_Value(), m_Value()))));
724+
EXPECT_FALSE(sd_match(MULS0123, m_ReassociatableMul(m_Value(), m_Value(),
725+
m_Value(), m_Value())));
726+
727+
// (Op0 && Op1) && (Op2 && Op3)
728+
SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
729+
SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
730+
SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
731+
732+
EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
733+
EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
734+
EXPECT_TRUE(sd_match(
735+
AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
736+
737+
// Op0 && (Op1 && (Op2 && Op3))
738+
SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
739+
SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
740+
EXPECT_TRUE(
741+
sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
742+
EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
743+
m_Value(), m_Value())));
744+
745+
// (Op0 - Op1) && (Op2 - Op3)
746+
SDValue ANDS0123 = DAG->getNode(ISD::AND, DL, Int32VT, SUB01, SUB23);
747+
EXPECT_TRUE(
748+
sd_match(ANDS0123, m_ReassociatableAnd(m_Sub(m_Value(), m_Value()),
749+
m_Sub(m_Value(), m_Value()))));
750+
EXPECT_FALSE(sd_match(ANDS0123, m_ReassociatableAnd(m_Value(), m_Value(),
751+
m_Value(), m_Value())));
752+
753+
// (Op0 || Op1) || (Op2 || Op3)
754+
SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
755+
SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
756+
SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
757+
758+
EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
759+
EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
760+
EXPECT_TRUE(sd_match(
761+
OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
762+
763+
// Op0 || (Op1 || (Op2 || Op3))
764+
SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
765+
SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
766+
EXPECT_TRUE(
767+
sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
768+
EXPECT_TRUE(sd_match(
769+
OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
770+
771+
// (Op0 - Op1) || (Op2 - Op3)
772+
SDValue ORS0123 = DAG->getNode(ISD::OR, DL, Int32VT, SUB01, SUB23);
773+
EXPECT_TRUE(
774+
sd_match(ORS0123, m_ReassociatableOr(m_Sub(m_Value(), m_Value()),
775+
m_Sub(m_Value(), m_Value()))));
776+
EXPECT_FALSE(sd_match(
777+
ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
778+
}

0 commit comments

Comments
 (0)