-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[DAGCombiner][VP] add getNegative for VPMatchContext #80635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-selectiondag Author: Shao-Ce SUN (sunshaoce) ChangesThis is my attempt to reuse existing code as much as possible, in order to provide a helper function for #80105. Full diff: https://github.com/llvm/llvm-project/pull/80635.diff 3 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b9ec30754f0c3..22981a6284fac 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -88,6 +88,32 @@ class TargetMachine;
class TargetSubtargetInfo;
class Value;
+class VPMaskAndVL {
+ bool IsVP;
+ SDValue MaskOp;
+ SDValue VectorLenOp;
+
+public:
+ VPMaskAndVL(SDValue Mask, SDValue VectorLen)
+ : IsVP(true), MaskOp(Mask), VectorLenOp(VectorLen) {}
+ VPMaskAndVL() : IsVP(false), MaskOp(), VectorLenOp() {}
+
+ bool empty() const { return !IsVP; }
+ bool isMaskEqualsTo(const SDValue &Val) const { return MaskOp == Val; }
+ bool isVLEqualsTo(const SDValue &Val) const { return VectorLenOp == Val; }
+ SDValue getMask() const { return MaskOp; }
+ SDValue getVL() const { return VectorLenOp; }
+
+ SDValue setMask(SDValue Val) {
+ IsVP = true;
+ return MaskOp = Val;
+ }
+ SDValue setVL(SDValue Val) {
+ IsVP = true;
+ return VectorLenOp = Val;
+ }
+};
+
template <typename T> class GenericSSAContext;
using SSAContext = GenericSSAContext<Function>;
template <typename T> class GenericUniformityInfo;
@@ -1004,7 +1030,8 @@ class SelectionDAG {
SDValue getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT, EVT OpVT);
/// Create negative operation as (SUB 0, Val).
- SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT);
+ SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+ VPMaskAndVL VPOp = VPMaskAndVL());
/// Create a bitwise NOT operation as (XOR Val, -1).
SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT);
@@ -1116,6 +1143,9 @@ class SelectionDAG {
ArrayRef<SDUse> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+ VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef<EVT> ResultTys,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
@@ -1124,6 +1154,8 @@ class SelectionDAG {
// Use flags from current flag inserter.
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3ce45e0e43bf4..3ed1a7533dd43 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -905,11 +905,16 @@ class EmptyMatchContext {
return Opcode == OpN->getOpcode();
}
- // Same as SelectionDAG::getNode().
- template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
- return DAG.getNode(std::forward<ArgT>(Args)...);
+ // Same as SelectionDAG::FUNCT_NAME(Args...).
+#define GET_SELECTION_DAG_FUNCT(FUNCT_NAME) \
+ template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
+ return DAG.FUNCT_NAME(std::forward<ArgT>(Args)...); \
}
+ GET_SELECTION_DAG_FUNCT(getNode)
+ GET_SELECTION_DAG_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_FUNCT
+
bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
@@ -919,22 +924,21 @@ class EmptyMatchContext {
class VPMatchContext {
SelectionDAG &DAG;
const TargetLowering &TLI;
- SDValue RootMaskOp;
- SDValue RootVectorLenOp;
+ VPMaskAndVL VPOp;
public:
VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
- : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
+ : DAG(DAG), TLI(TLI), VPOp() {
assert(Root->isVPOpcode());
if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
- RootMaskOp = Root->getOperand(*RootMaskPos);
+ VPOp.setMask(Root->getOperand(*RootMaskPos));
else if (Root->getOpcode() == ISD::VP_SELECT)
- RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
- Root->getOperand(0).getValueType());
+ VPOp.setMask(DAG.getAllOnesConstant(SDLoc(Root),
+ Root->getOperand(0).getValueType()));
if (auto RootVLenPos =
ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
- RootVectorLenOp = Root->getOperand(*RootVLenPos);
+ VPOp.setVL(Root->getOperand(*RootVLenPos));
}
/// whether \p OpVal is a node that is functionally compatible with the
@@ -952,14 +956,14 @@ class VPMatchContext {
unsigned VPOpcode = OpVal->getOpcode();
if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
SDValue MaskOp = OpVal.getOperand(*MaskPos);
- if (RootMaskOp != MaskOp &&
+ if (!VPOp.isMaskEqualsTo(MaskOp) &&
!ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
return false;
}
// Make sure the EVL of OpVal is same as Root's.
if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
- if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
+ if (!VPOp.isVLEqualsTo(OpVal.getOperand(*VLenPos)))
return false;
return true;
}
@@ -972,8 +976,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT,
- {Operand, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {Operand}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -981,8 +984,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -990,8 +992,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
@@ -999,8 +1000,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
- Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {Operand}, Flags, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1008,8 +1008,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
- Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, Flags, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1017,10 +1016,18 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, Flags, VPOp);
+ }
+
+ // Same as SelectionDAG::FUNCT_NAME(Args, VPOp).
+#define GET_SELECTION_DAG_VP_FUNCT(FUNCT_NAME) \
+ template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
+ return DAG.FUNCT_NAME(std::forward<ArgT>(Args)..., VPOp); \
}
+ GET_SELECTION_DAG_VP_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_VP_FUNCT
+
bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c1343836187a..0c6e70a86caf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1548,8 +1548,10 @@ SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
return getZeroExtendInReg(Op, DL, VT);
}
-SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
- return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
+SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+ VPMaskAndVL VPOp) {
+ auto Opcode = VPOp.empty() ? ISD::SUB : ISD::VP_SUB;
+ return getNode(Opcode, DL, VT, {getConstant(0, DL, VT), Val}, VPOp);
}
/// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(Opcode, DL, VT, Ops, Flags);
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, VPMaskAndVL VPOp) {
+ SmallVector<SDValue, 8> OpsVec(Ops);
+ if (!VPOp.empty()) {
+ OpsVec.push_back(VPOp.getMask());
+ OpsVec.push_back(VPOp.getVL());
+ }
+ return getNode(Opcode, DL, VT, OpsVec);
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
unsigned NumOps = Ops.size();
@@ -9820,6 +9832,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return V;
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+ VPMaskAndVL VPOp) {
+ SmallVector<SDValue, 8> OpsVec(Ops);
+ if (!VPOp.empty()) {
+ OpsVec.push_back(VPOp.getMask());
+ OpsVec.push_back(VPOp.getVL());
+ }
+ return getNode(Opcode, DL, VT, OpsVec, Flags);
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
return getNode(Opcode, DL, getVTList(ResultTys), Ops);
|
@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, | |||
return getNode(Opcode, DL, VT, Ops, Flags); | |||
} | |||
|
|||
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic seems different between this function and getNegative
. getNegative
will use the VP version SUB
automatically if we have a non-null VPOp
. But these 2 new getNode
ask the user to chose right opcode even giving a non-null VPOp
. It's a little bit inconsistent.
This is my attempt to reuse existing code as much as possible, in order to provide a helper function for #80105.