Skip to content

[DAG] Add SDPatternMatch::m_SetCC and update some combines to use it #98646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 14, 2024

Conversation

RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented Jul 12, 2024

The plan is to add more TernaryOp in the future (SELECT/VSELECT and FMA in particular)

@RKSimon RKSimon requested a review from mshockwave July 12, 2024 15:05
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Jul 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 12, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Simon Pilgrim (RKSimon)

Changes

The plan is to add more TernaryOp in the future (SELECT/VSELECT and FMA in particular)


Full diff: https://github.com/llvm/llvm-project/pull/98646.diff

2 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+43)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-31)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index f39fbd95b3beb..ee2c0610c92b8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
   explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
 };
 
+// === Ternary operations ===
+template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
+          bool ExcludeChain = false>
+struct TernaryOpc_match {
+  unsigned Opcode;
+  T0_P Op0;
+  T1_P Op1;
+  T2_P Op2;
+
+  TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
+                   const T2_P &Op2)
+      : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
+      EffectiveOperands<ExcludeChain> EO(N);
+      assert(EO.Size == 3);
+      return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+              (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
+             Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
+    }
+
+    return false;
+  }
+};
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, false>
+m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return TernaryOpc_match<T0_P, T1_P, T2_P, false, false>(ISD::SETCC, Op0, Op1,
+                                                          Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, false>
+m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, Op0, Op1,
+                                                         Op2);
+}
+
 // === Binary operations ===
 template <typename LHS_P, typename RHS_P, bool Commutable = false,
           bool ExcludeChain = false>
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 428cdda21cd41..9185c0176de12 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
     return true;
   }
 
-  if (N.getOpcode() != ISD::SETCC ||
-      N.getValueType().getScalarType() != MVT::i1 ||
-      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
-    return false;
-
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  assert(Op0.getValueType() == Op1.getValueType());
-
-  if (isNullOrNullSplat(Op0))
-    Op = Op1;
-  else if (isNullOrNullSplat(Op1))
-    Op = Op0;
-  else
+  if (N.getValueType().getScalarType() != MVT::i1 ||
+      !sd_match(N,
+                m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
     return false;
 
   Known = DAG.computeKnownBits(Op);
-
   return (Known.Zero | 1).isAllOnes();
 }
 
@@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
     return SDValue();
 
   // Match the zext operand as a setcc of a boolean.
-  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
-      Z.getOperand(0).getValueType() != MVT::i1)
+  if (Z.getOperand(0).getValueType() != MVT::i1)
     return SDValue();
 
   // Match the compare as: setcc (X & 1), 0, eq.
-  SDValue SetCC = Z.getOperand(0);
-  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
-  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
-      SetCC.getOperand(0).getOpcode() != ISD::AND ||
-      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+  if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
+                                         m_SpecificCondCode(ISD::SETEQ))))
     return SDValue();
 
   // We are adding/subtracting a constant and an inverted low bit. Turn that
@@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
   EVT VT = C.getValueType();
-  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
-  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
-                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+  SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
+  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
+                     : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
 }
 
@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
   EVT VT = N->getValueType(0);
-  if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
-    return SDValue();
 
-  SDValue Cond0 = N0.getOperand(0);
-  SDValue Cond1 = N0.getOperand(1);
-  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
-  if (VT != Cond0.getValueType())
+  SDValue Cond0, Cond1;
+  ISD::CondCode CC;
+  if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
+                                     m_CondCode(CC)))) ||
+      VT != Cond0.getValueType())
     return SDValue();
 
   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the

Copy link

github-actions bot commented Jul 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@RKSimon RKSimon force-pushed the sd_match_setcc branch 2 times, most recently from b633bae to aa7a65b Compare July 12, 2024 16:51
Copy link
Member

@mshockwave mshockwave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some (unit) tests?

Copy link
Member

@mshockwave mshockwave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks

@RKSimon RKSimon merged commit 61a4e1e into llvm:main Jul 14, 2024
7 checks passed
@RKSimon RKSimon deleted the sd_match_setcc branch July 14, 2024 16:18
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…lvm#98646)

The plan is to add more TernaryOp in the future (SELECT/VSELECT and FMA in particular)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants