Skip to content

Commit b520126

Browse files
author
v01dxyz
committed
[SDPatternMatch][VP] Operands pattern: Support VPMatchContext
Ignore last two operands if Node is VP
1 parent 4ac42af commit b520126

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class BasicMatchContext {
4747
bool match(SDValue N, unsigned Opcode) const {
4848
return N->getOpcode() == Opcode;
4949
}
50+
51+
static constexpr bool IsVP = false;
5052
};
5153

5254
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
390392
template <typename MatchContext>
391393
bool match(const MatchContext &Ctx, SDValue N) {
392394
// Returns false if there are more operands than predicates;
393-
return N->getNumOperands() == OpIdx;
395+
// Ignores the last two operands if both the Context and the Node are VP
396+
return N->getNumOperands() == (OpIdx + 2 * Ctx.IsVP * N->isVPOpcode());
394397
}
395398
};
396399

@@ -464,7 +467,7 @@ struct TernaryOpc_match {
464467
bool match(const MatchContext &Ctx, SDValue N) {
465468
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
466469
EffectiveOperands<ExcludeChain> EO(N);
467-
assert(EO.Size == 3);
470+
assert(EO.Size == 3U + 2 * N->isVPOpcode());
468471
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
469472
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
470473
(Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -516,7 +519,7 @@ struct BinaryOpc_match {
516519
bool match(const MatchContext &Ctx, SDValue N) {
517520
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
518521
EffectiveOperands<ExcludeChain> EO(N);
519-
assert(EO.Size == 2);
522+
assert(EO.Size == 2U + 2 * N->isVPOpcode());
520523
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
521524
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
522525
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -668,7 +671,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
668671
bool match(const MatchContext &Ctx, SDValue N) {
669672
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
670673
EffectiveOperands<ExcludeChain> EO(N);
671-
assert(EO.Size == 1);
674+
assert(EO.Size == 1U + 2 * N->isVPOpcode());
672675
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
673676
}
674677

llvm/lib/CodeGen/SelectionDAG/MatchContext.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class EmptyMatchContext {
4646
bool LegalOnly = false) const {
4747
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
4848
}
49+
50+
static constexpr bool IsVP = false;
4951
};
5052

5153
class VPMatchContext {
@@ -170,6 +172,8 @@ class VPMatchContext {
170172
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
171173
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
172174
}
175+
176+
static constexpr bool IsVP = true;
173177
};
174178
} // end anonymous namespace
175179
#endif

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
383383
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
384384
return BaseOpc.has_value() && *BaseOpc == Opc;
385385
}
386+
387+
static constexpr bool IsVP = true;
386388
};
387389
} // anonymous namespace
388390
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +402,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
400402
{Vector0, Vector0, Mask0, Scalar0});
401403
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
402404
{Scalar0, VPAdd, Mask0, Scalar0});
405+
SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
403406

404407
using namespace SDPatternMatch;
405408
VPMatchContext VPCtx(DAG.get());
406409
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
410+
EXPECT_TRUE(
411+
sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
412+
// VPMatchContext can't match pattern using explicit VP Opcode
413+
EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
414+
m_Node(ISD::VP_ADD, m_Value(), m_Value())));
415+
EXPECT_FALSE(sd_context_match(
416+
VPAdd, VPCtx,
417+
m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
418+
// Check Binary Op Pattern
419+
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
407420
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
408421
// sd_match before switching to VPMatchContext when checking VPAdd.
409422
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
410423
m_Context(VPCtx, m_Opc(ISD::ADD)),
411424
m_Value(), m_Value())));
425+
// non-vector predicated should match too
426+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
427+
EXPECT_TRUE(
428+
sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
429+
EXPECT_FALSE(sd_context_match(
430+
Add, VPCtx,
431+
m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
432+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
433+
}
434+
435+
TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
436+
SDLoc DL;
437+
auto BoolVT = EVT::getIntegerVT(Context, 1);
438+
auto Int32VT = EVT::getIntegerVT(Context, 32);
439+
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
440+
auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
441+
442+
SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
443+
SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
444+
SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
445+
446+
SDValue VPAdd =
447+
DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
448+
449+
using namespace SDPatternMatch;
450+
EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
451+
EXPECT_TRUE(sd_match(
452+
VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
412453
}
413454

414455
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {

0 commit comments

Comments
 (0)