Skip to content

Commit 771baaf

Browse files
author
v01dxyz
committed
[SDPatternMatch][VP] Operands pattern: Support VPMatchContext
Ignore last two operands if Node is VP
1 parent 4ac42af commit 771baaf

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class BasicMatchContext {
4747
bool match(SDValue N, unsigned Opcode) const {
4848
return N->getOpcode() == Opcode;
4949
}
50+
51+
unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
5052
};
5153

5254
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
390392
template <typename MatchContext>
391393
bool match(const MatchContext &Ctx, SDValue N) {
392394
// Returns false if there are more operands than predicates;
393-
return N->getNumOperands() == OpIdx;
395+
// Ignores the last two operands if both the Context and the Node are VP
396+
return Ctx.getNumOperands(N) == OpIdx;
394397
}
395398
};
396399

@@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
424427
unsigned Size = 0;
425428
unsigned FirstIndex = 0;
426429

427-
explicit EffectiveOperands(SDValue N) {
428-
const unsigned TotalNumOps = N->getNumOperands();
430+
template <typename MatchContext>
431+
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
432+
const unsigned TotalNumOps = Ctx.getNumOperands(N);
429433
FirstIndex = TotalNumOps;
430434
for (unsigned I = 0; I < TotalNumOps; ++I) {
431435
// Count the number of non-chain and non-glue nodes (we ignore chain
@@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
444448
unsigned Size = 0;
445449
unsigned FirstIndex = 0;
446450

447-
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
451+
template <typename MatchContext>
452+
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
453+
: Size(Ctx.getNumOperands(N)) {}
448454
};
449455

450456
// === Ternary operations ===
@@ -463,7 +469,7 @@ struct TernaryOpc_match {
463469
template <typename MatchContext>
464470
bool match(const MatchContext &Ctx, SDValue N) {
465471
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
466-
EffectiveOperands<ExcludeChain> EO(N);
472+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
467473
assert(EO.Size == 3);
468474
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
469475
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -515,7 +521,7 @@ struct BinaryOpc_match {
515521
template <typename MatchContext>
516522
bool match(const MatchContext &Ctx, SDValue N) {
517523
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
518-
EffectiveOperands<ExcludeChain> EO(N);
524+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
519525
assert(EO.Size == 2);
520526
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
521527
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
667673
template <typename MatchContext>
668674
bool match(const MatchContext &Ctx, SDValue N) {
669675
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
670-
EffectiveOperands<ExcludeChain> EO(N);
676+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
671677
assert(EO.Size == 1);
672678
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
673679
}

llvm/lib/CodeGen/SelectionDAG/MatchContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class EmptyMatchContext {
4646
bool LegalOnly = false) const {
4747
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
4848
}
49+
50+
unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
4951
};
5052

5153
class VPMatchContext {
@@ -170,6 +172,10 @@ class VPMatchContext {
170172
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
171173
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
172174
}
175+
176+
unsigned getNumOperands(SDValue N) const {
177+
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
178+
}
173179
};
174180
} // end anonymous namespace
175181
#endif

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
383383
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
384384
return BaseOpc.has_value() && *BaseOpc == Opc;
385385
}
386+
387+
unsigned getNumOperands(SDValue N) const {
388+
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
389+
}
386390
};
387391
} // anonymous namespace
388392
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +404,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
400404
{Vector0, Vector0, Mask0, Scalar0});
401405
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
402406
{Scalar0, VPAdd, Mask0, Scalar0});
407+
SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
403408

404409
using namespace SDPatternMatch;
405410
VPMatchContext VPCtx(DAG.get());
406411
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
412+
EXPECT_TRUE(
413+
sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
414+
// VPMatchContext can't match pattern using explicit VP Opcode
415+
EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
416+
m_Node(ISD::VP_ADD, m_Value(), m_Value())));
417+
EXPECT_FALSE(sd_context_match(
418+
VPAdd, VPCtx,
419+
m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
420+
// Check Binary Op Pattern
421+
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
407422
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
408423
// sd_match before switching to VPMatchContext when checking VPAdd.
409424
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
410425
m_Context(VPCtx, m_Opc(ISD::ADD)),
411426
m_Value(), m_Value())));
427+
// non-vector predicated should match too
428+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
429+
EXPECT_TRUE(
430+
sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
431+
EXPECT_FALSE(sd_context_match(
432+
Add, VPCtx,
433+
m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
434+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
435+
}
436+
437+
TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
438+
SDLoc DL;
439+
auto BoolVT = EVT::getIntegerVT(Context, 1);
440+
auto Int32VT = EVT::getIntegerVT(Context, 32);
441+
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
442+
auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
443+
444+
SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
445+
SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
446+
SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
447+
448+
SDValue VPAdd =
449+
DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
450+
451+
using namespace SDPatternMatch;
452+
EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
453+
EXPECT_TRUE(sd_match(
454+
VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
412455
}
413456

414457
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {

0 commit comments

Comments
 (0)