Skip to content

[DAG] Add SDPatternMatch::m_SetCC and update some combines to use it #98646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
};

// === Ternary operations ===
template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
bool ExcludeChain = false>
struct TernaryOpc_match {
unsigned Opcode;
T0_P Op0;
T1_P Op1;
T2_P Op2;

TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
const T2_P &Op2)
: Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}

template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
assert(EO.Size == 3);
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
(Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
}

return false;
}
};

template <typename T0_P, typename T1_P, typename T2_P>
inline TernaryOpc_match<T0_P, T1_P, T2_P, false, false>
m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return TernaryOpc_match<T0_P, T1_P, T2_P, false, false>(ISD::SETCC, Op0, Op1,
Op2);
}

template <typename T0_P, typename T1_P, typename T2_P>
inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, Op0, Op1,
Op2);
}

// === Binary operations ===
template <typename LHS_P, typename RHS_P, bool Commutable = false,
bool ExcludeChain = false>
Expand Down
45 changes: 14 additions & 31 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
return true;
}

if (N.getOpcode() != ISD::SETCC ||
N.getValueType().getScalarType() != MVT::i1 ||
cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
return false;

SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
assert(Op0.getValueType() == Op1.getValueType());

if (isNullOrNullSplat(Op0))
Op = Op1;
else if (isNullOrNullSplat(Op1))
Op = Op0;
else
if (N.getValueType().getScalarType() != MVT::i1 ||
!sd_match(
N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
return false;

Known = DAG.computeKnownBits(Op);

return (Known.Zero | 1).isAllOnes();
}

Expand Down Expand Up @@ -2544,26 +2532,22 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
return SDValue();

// Match the zext operand as a setcc of a boolean.
if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
Z.getOperand(0).getValueType() != MVT::i1)
if (Z.getOperand(0).getValueType() != MVT::i1)
return SDValue();

// Match the compare as: setcc (X & 1), 0, eq.
SDValue SetCC = Z.getOperand(0);
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
SetCC.getOperand(0).getOpcode() != ISD::AND ||
!isOneConstant(SetCC.getOperand(0).getOperand(1)))
if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
m_SpecificCondCode(ISD::SETEQ))))
return SDValue();

// We are adding/subtracting a constant and an inverted low bit. Turn that
// into a subtract/add of the low bit with incremented/decremented constant:
// add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
// sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
EVT VT = C.getValueType();
SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
: DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}

Expand Down Expand Up @@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
EVT VT = N->getValueType(0);
if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
return SDValue();

SDValue Cond0 = N0.getOperand(0);
SDValue Cond1 = N0.getOperand(1);
ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
if (VT != Cond0.getValueType())
SDValue Cond0, Cond1;
ISD::CondCode CC;
if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
m_CondCode(CC)))) ||
VT != Cond0.getValueType())
return SDValue();

// Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
Expand Down
35 changes: 35 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,41 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
}

TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);

SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);

SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT);
SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ);
SDValue ICMP_EQ10 = DAG->getSetCC(DL, MVT::i1, Op1, Op0, ISD::SETEQ);

using namespace SDPatternMatch;
ISD::CondCode CC;
EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
m_SpecificCondCode(ISD::SETUGT))));
EXPECT_TRUE(
sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC))));
EXPECT_TRUE(CC == ISD::SETUGT);
EXPECT_FALSE(sd_match(
ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_SpecificCondCode(ISD::SETLE))));

EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_TRUE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1),
m_SpecificCondCode(ISD::SETEQ))));
}

TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
Expand Down
Loading