Skip to content

SD Pattern Match: Operands patterns with VP Context #103308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class BasicMatchContext {
bool match(SDValue N, unsigned Opcode) const {
return N->getOpcode() == Opcode;
}

unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};

template <typename Pattern, typename MatchContext>
Expand Down Expand Up @@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
// Returns false if there are more operands than predicates;
return N->getNumOperands() == OpIdx;
// Ignores the last two operands if both the Context and the Node are VP
return Ctx.getNumOperands(N) == OpIdx;
}
};

Expand Down Expand Up @@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
unsigned Size = 0;
unsigned FirstIndex = 0;

explicit EffectiveOperands(SDValue N) {
const unsigned TotalNumOps = N->getNumOperands();
template <typename MatchContext>
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
const unsigned TotalNumOps = Ctx.getNumOperands(N);
FirstIndex = TotalNumOps;
for (unsigned I = 0; I < TotalNumOps; ++I) {
// Count the number of non-chain and non-glue nodes (we ignore chain
Expand All @@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
unsigned Size = 0;
unsigned FirstIndex = 0;

explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
template <typename MatchContext>
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
: Size(Ctx.getNumOperands(N)) {}
};

// === Ternary operations ===
Expand All @@ -463,7 +469,7 @@ struct TernaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 3);
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
Expand Down Expand Up @@ -515,7 +521,7 @@ struct BinaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 2);
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
Expand Down Expand Up @@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 1);
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
}
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/MatchContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class EmptyMatchContext {
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
}

unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};

class VPMatchContext {
Expand Down Expand Up @@ -170,6 +172,10 @@ class VPMatchContext {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
}

unsigned getNumOperands(SDValue N) const {
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
}
};
} // end anonymous namespace
#endif
43 changes: 43 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
return BaseOpc.has_value() && *BaseOpc == Opc;
}

unsigned getNumOperands(SDValue N) const {
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
}
};
} // anonymous namespace
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
Expand All @@ -400,15 +404,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
{Vector0, Vector0, Mask0, Scalar0});
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
{Scalar0, VPAdd, Mask0, Scalar0});
SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});

using namespace SDPatternMatch;
VPMatchContext VPCtx(DAG.get());
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
EXPECT_TRUE(
sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
// VPMatchContext can't match pattern using explicit VP Opcode
EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
m_Node(ISD::VP_ADD, m_Value(), m_Value())));
EXPECT_FALSE(sd_context_match(
VPAdd, VPCtx,
m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
// Check Binary Op Pattern
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
// sd_match before switching to VPMatchContext when checking VPAdd.
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
m_Context(VPCtx, m_Opc(ISD::ADD)),
m_Value(), m_Value())));
// non-vector predicated should match too
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
EXPECT_TRUE(
sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
EXPECT_FALSE(sd_context_match(
Add, VPCtx,
m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
}

TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
SDLoc DL;
auto BoolVT = EVT::getIntegerVT(Context, 1);
auto Int32VT = EVT::getIntegerVT(Context, 32);
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);

SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);

SDValue VPAdd =
DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);

using namespace SDPatternMatch;
EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
}

TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
Expand Down
Loading