Skip to content

[DAGCombiner][VP] add getNegative for VPMatchContext #80635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

sunshaoce
Copy link
Contributor

This is my attempt to reuse existing code as much as possible, in order to provide a helper function for #80105.

@sunshaoce sunshaoce marked this pull request as ready for review February 5, 2024 06:09
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Feb 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Shao-Ce SUN (sunshaoce)

Changes

This is my attempt to reuse existing code as much as possible, in order to provide a helper function for #80105.


Full diff: https://github.com/llvm/llvm-project/pull/80635.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+33-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+31-24)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+25-2)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b9ec30754f0c3..22981a6284fac 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -88,6 +88,32 @@ class TargetMachine;
 class TargetSubtargetInfo;
 class Value;
 
+class VPMaskAndVL {
+  bool IsVP;
+  SDValue MaskOp;
+  SDValue VectorLenOp;
+
+public:
+  VPMaskAndVL(SDValue Mask, SDValue VectorLen)
+      : IsVP(true), MaskOp(Mask), VectorLenOp(VectorLen) {}
+  VPMaskAndVL() : IsVP(false), MaskOp(), VectorLenOp() {}
+
+  bool empty() const { return !IsVP; }
+  bool isMaskEqualsTo(const SDValue &Val) const { return MaskOp == Val; }
+  bool isVLEqualsTo(const SDValue &Val) const { return VectorLenOp == Val; }
+  SDValue getMask() const { return MaskOp; }
+  SDValue getVL() const { return VectorLenOp; }
+
+  SDValue setMask(SDValue Val) {
+    IsVP = true;
+    return MaskOp = Val;
+  }
+  SDValue setVL(SDValue Val) {
+    IsVP = true;
+    return VectorLenOp = Val;
+  }
+};
+
 template <typename T> class GenericSSAContext;
 using SSAContext = GenericSSAContext<Function>;
 template <typename T> class GenericUniformityInfo;
@@ -1004,7 +1030,8 @@ class SelectionDAG {
   SDValue getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT, EVT OpVT);
 
   /// Create negative operation as (SUB 0, Val).
-  SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT);
+  SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+                      VPMaskAndVL VPOp = VPMaskAndVL());
 
   /// Create a bitwise NOT operation as (XOR Val, -1).
   SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT);
@@ -1116,6 +1143,9 @@ class SelectionDAG {
                   ArrayRef<SDUse> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                   ArrayRef<SDValue> Ops, const SDNodeFlags Flags);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                  ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+                  VPMaskAndVL VPOp);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef<EVT> ResultTys,
                   ArrayRef<SDValue> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
@@ -1124,6 +1154,8 @@ class SelectionDAG {
   // Use flags from current flag inserter.
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                   ArrayRef<SDValue> Ops);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                  ArrayRef<SDValue> Ops, VPMaskAndVL VPOp);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                   ArrayRef<SDValue> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3ce45e0e43bf4..3ed1a7533dd43 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -905,11 +905,16 @@ class EmptyMatchContext {
     return Opcode == OpN->getOpcode();
   }
 
-  // Same as SelectionDAG::getNode().
-  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
-    return DAG.getNode(std::forward<ArgT>(Args)...);
+  // Same as SelectionDAG::FUNCT_NAME(Args...).
+#define GET_SELECTION_DAG_FUNCT(FUNCT_NAME)                                    \
+  template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) {             \
+    return DAG.FUNCT_NAME(std::forward<ArgT>(Args)...);                        \
   }
 
+  GET_SELECTION_DAG_FUNCT(getNode)
+  GET_SELECTION_DAG_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_FUNCT
+
   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                 bool LegalOnly = false) const {
     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
@@ -919,22 +924,21 @@ class EmptyMatchContext {
 class VPMatchContext {
   SelectionDAG &DAG;
   const TargetLowering &TLI;
-  SDValue RootMaskOp;
-  SDValue RootVectorLenOp;
+  VPMaskAndVL VPOp;
 
 public:
   VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
+      : DAG(DAG), TLI(TLI), VPOp() {
     assert(Root->isVPOpcode());
     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
-      RootMaskOp = Root->getOperand(*RootMaskPos);
+      VPOp.setMask(Root->getOperand(*RootMaskPos));
     else if (Root->getOpcode() == ISD::VP_SELECT)
-      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
-                                          Root->getOperand(0).getValueType());
+      VPOp.setMask(DAG.getAllOnesConstant(SDLoc(Root),
+                                          Root->getOperand(0).getValueType()));
 
     if (auto RootVLenPos =
             ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
-      RootVectorLenOp = Root->getOperand(*RootVLenPos);
+      VPOp.setVL(Root->getOperand(*RootVLenPos));
   }
 
   /// whether \p OpVal is a node that is functionally compatible with the
@@ -952,14 +956,14 @@ class VPMatchContext {
     unsigned VPOpcode = OpVal->getOpcode();
     if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
       SDValue MaskOp = OpVal.getOperand(*MaskPos);
-      if (RootMaskOp != MaskOp &&
+      if (!VPOp.isMaskEqualsTo(MaskOp) &&
           !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
         return false;
     }
 
     // Make sure the EVL of OpVal is same as Root's.
     if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
-      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
+      if (!VPOp.isVLEqualsTo(OpVal.getOperand(*VLenPos)))
         return false;
     return true;
   }
@@ -972,8 +976,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {Operand, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {Operand}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -981,8 +984,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -990,8 +992,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
@@ -999,8 +1000,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
-                       Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {Operand}, Flags, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1008,8 +1008,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
-                       Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, Flags, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1017,10 +1016,18 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, Flags, VPOp);
+  }
+
+  // Same as SelectionDAG::FUNCT_NAME(Args, VPOp).
+#define GET_SELECTION_DAG_VP_FUNCT(FUNCT_NAME)                                 \
+  template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) {             \
+    return DAG.FUNCT_NAME(std::forward<ArgT>(Args)..., VPOp);                  \
   }
 
+  GET_SELECTION_DAG_VP_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_VP_FUNCT
+
   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                 bool LegalOnly = false) const {
     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c1343836187a..0c6e70a86caf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1548,8 +1548,10 @@ SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
   return getZeroExtendInReg(Op, DL, VT);
 }
 
-SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
-  return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
+SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+                                  VPMaskAndVL VPOp) {
+  auto Opcode = VPOp.empty() ? ISD::SUB : ISD::VP_SUB;
+  return getNode(Opcode, DL, VT, {getConstant(0, DL, VT), Val}, VPOp);
 }
 
 /// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   return getNode(Opcode, DL, VT, Ops, Flags);
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              ArrayRef<SDValue> Ops, VPMaskAndVL VPOp) {
+  SmallVector<SDValue, 8> OpsVec(Ops);
+  if (!VPOp.empty()) {
+    OpsVec.push_back(VPOp.getMask());
+    OpsVec.push_back(VPOp.getVL());
+  }
+  return getNode(Opcode, DL, VT, OpsVec);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   unsigned NumOps = Ops.size();
@@ -9820,6 +9832,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   return V;
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+                              VPMaskAndVL VPOp) {
+  SmallVector<SDValue, 8> OpsVec(Ops);
+  if (!VPOp.empty()) {
+    OpsVec.push_back(VPOp.getMask());
+    OpsVec.push_back(VPOp.getVL());
+  }
+  return getNode(Opcode, DL, VT, OpsVec, Flags);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
                               ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
   return getNode(Opcode, DL, getVTList(ResultTys), Ops);

@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(Opcode, DL, VT, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic seems different between this function and getNegative. getNegative will use the VP version SUB automatically if we have a non-null VPOp. But these 2 new getNode ask the user to chose right opcode even giving a non-null VPOp. It's a little bit inconsistent.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants