Skip to content

Commit fc1b019

Browse files
v01dXYZv01dxyz
andauthored
[DAG] SD Pattern Match: Operands patterns with VP Context (#103308)
Currently, when using a VP match context with `sd_context_match`, only Opcode matching is possible (`m_Opc(Opcode)`). This PR suggest a way to make patterns with Operands (eg `m_Node`, `m_Add`, ...) works with a VP context. This PR blocks another PR #102877. Co-authored-by: v01dxyz <[email protected]>
1 parent 8ca5ff2 commit fc1b019

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class BasicMatchContext {
4747
bool match(SDValue N, unsigned Opcode) const {
4848
return N->getOpcode() == Opcode;
4949
}
50+
51+
unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
5052
};
5153

5254
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
390392
template <typename MatchContext>
391393
bool match(const MatchContext &Ctx, SDValue N) {
392394
// Returns false if there are more operands than predicates;
393-
return N->getNumOperands() == OpIdx;
395+
// Ignores the last two operands if both the Context and the Node are VP
396+
return Ctx.getNumOperands(N) == OpIdx;
394397
}
395398
};
396399

@@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
424427
unsigned Size = 0;
425428
unsigned FirstIndex = 0;
426429

427-
explicit EffectiveOperands(SDValue N) {
428-
const unsigned TotalNumOps = N->getNumOperands();
430+
template <typename MatchContext>
431+
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
432+
const unsigned TotalNumOps = Ctx.getNumOperands(N);
429433
FirstIndex = TotalNumOps;
430434
for (unsigned I = 0; I < TotalNumOps; ++I) {
431435
// Count the number of non-chain and non-glue nodes (we ignore chain
@@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
444448
unsigned Size = 0;
445449
unsigned FirstIndex = 0;
446450

447-
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
451+
template <typename MatchContext>
452+
explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
453+
: Size(Ctx.getNumOperands(N)) {}
448454
};
449455

450456
// === Ternary operations ===
@@ -463,7 +469,7 @@ struct TernaryOpc_match {
463469
template <typename MatchContext>
464470
bool match(const MatchContext &Ctx, SDValue N) {
465471
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
466-
EffectiveOperands<ExcludeChain> EO(N);
472+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
467473
assert(EO.Size == 3);
468474
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
469475
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -515,7 +521,7 @@ struct BinaryOpc_match {
515521
template <typename MatchContext>
516522
bool match(const MatchContext &Ctx, SDValue N) {
517523
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
518-
EffectiveOperands<ExcludeChain> EO(N);
524+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
519525
assert(EO.Size == 2);
520526
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
521527
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
667673
template <typename MatchContext>
668674
bool match(const MatchContext &Ctx, SDValue N) {
669675
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
670-
EffectiveOperands<ExcludeChain> EO(N);
676+
EffectiveOperands<ExcludeChain> EO(N, Ctx);
671677
assert(EO.Size == 1);
672678
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
673679
}

llvm/lib/CodeGen/SelectionDAG/MatchContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class EmptyMatchContext {
4545
bool LegalOnly = false) const {
4646
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
4747
}
48+
49+
unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
4850
};
4951

5052
class VPMatchContext {
@@ -169,6 +171,10 @@ class VPMatchContext {
169171
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
170172
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
171173
}
174+
175+
unsigned getNumOperands(SDValue N) const {
176+
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
177+
}
172178
};
173179

174180
} // namespace llvm

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
393393
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
394394
return BaseOpc.has_value() && *BaseOpc == Opc;
395395
}
396+
397+
unsigned getNumOperands(SDValue N) const {
398+
return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
399+
}
396400
};
397401
} // anonymous namespace
398402
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -410,15 +414,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
410414
{Vector0, Vector0, Mask0, Scalar0});
411415
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
412416
{Scalar0, VPAdd, Mask0, Scalar0});
417+
SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
413418

414419
using namespace SDPatternMatch;
415420
VPMatchContext VPCtx(DAG.get());
416421
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
422+
EXPECT_TRUE(
423+
sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
424+
// VPMatchContext can't match pattern using explicit VP Opcode
425+
EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
426+
m_Node(ISD::VP_ADD, m_Value(), m_Value())));
427+
EXPECT_FALSE(sd_context_match(
428+
VPAdd, VPCtx,
429+
m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
430+
// Check Binary Op Pattern
431+
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
417432
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
418433
// sd_match before switching to VPMatchContext when checking VPAdd.
419434
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
420435
m_Context(VPCtx, m_Opc(ISD::ADD)),
421436
m_Value(), m_Value())));
437+
// non-vector predicated should match too
438+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
439+
EXPECT_TRUE(
440+
sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
441+
EXPECT_FALSE(sd_context_match(
442+
Add, VPCtx,
443+
m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
444+
EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
445+
}
446+
447+
TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
448+
SDLoc DL;
449+
auto BoolVT = EVT::getIntegerVT(Context, 1);
450+
auto Int32VT = EVT::getIntegerVT(Context, 32);
451+
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
452+
auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
453+
454+
SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
455+
SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
456+
SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
457+
458+
SDValue VPAdd =
459+
DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
460+
461+
using namespace SDPatternMatch;
462+
EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
463+
EXPECT_TRUE(sd_match(
464+
VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
422465
}
423466

424467
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {

0 commit comments

Comments
 (0)