Skip to content

[DAGCombiner][VP] add getNegative for VPMatchContext #80635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,32 @@ class TargetMachine;
class TargetSubtargetInfo;
class Value;

class VPMaskAndVL {
bool IsVP;
SDValue MaskOp;
SDValue VectorLenOp;

public:
VPMaskAndVL(SDValue Mask, SDValue VectorLen)
: IsVP(true), MaskOp(Mask), VectorLenOp(VectorLen) {}
VPMaskAndVL() : IsVP(false), MaskOp(), VectorLenOp() {}

bool empty() const { return !IsVP; }
bool isMaskEqualsTo(const SDValue &Val) const { return MaskOp == Val; }
bool isVLEqualsTo(const SDValue &Val) const { return VectorLenOp == Val; }
SDValue getMask() const { return MaskOp; }
SDValue getVL() const { return VectorLenOp; }

SDValue setMask(SDValue Val) {
IsVP = true;
return MaskOp = Val;
}
SDValue setVL(SDValue Val) {
IsVP = true;
return VectorLenOp = Val;
}
};

template <typename T> class GenericSSAContext;
using SSAContext = GenericSSAContext<Function>;
template <typename T> class GenericUniformityInfo;
Expand Down Expand Up @@ -1004,7 +1030,8 @@ class SelectionDAG {
SDValue getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT, EVT OpVT);

/// Create negative operation as (SUB 0, Val).
SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT);
SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT,
VPMaskAndVL VPOp = VPMaskAndVL());

/// Create a bitwise NOT operation as (XOR Val, -1).
SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT);
Expand Down Expand Up @@ -1116,6 +1143,9 @@ class SelectionDAG {
ArrayRef<SDUse> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef<EVT> ResultTys,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
Expand All @@ -1124,6 +1154,8 @@ class SelectionDAG {
// Use flags from current flag inserter.
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand);
Expand Down
55 changes: 31 additions & 24 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,11 +905,16 @@ class EmptyMatchContext {
return Opcode == OpN->getOpcode();
}

// Same as SelectionDAG::getNode().
template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
return DAG.getNode(std::forward<ArgT>(Args)...);
// Same as SelectionDAG::FUNCT_NAME(Args...).
#define GET_SELECTION_DAG_FUNCT(FUNCT_NAME) \
template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
return DAG.FUNCT_NAME(std::forward<ArgT>(Args)...); \
}

GET_SELECTION_DAG_FUNCT(getNode)
GET_SELECTION_DAG_FUNCT(getNegative)
#undef GET_SELECTION_DAG_FUNCT

bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
Expand All @@ -919,22 +924,21 @@ class EmptyMatchContext {
class VPMatchContext {
SelectionDAG &DAG;
const TargetLowering &TLI;
SDValue RootMaskOp;
SDValue RootVectorLenOp;
VPMaskAndVL VPOp;

public:
VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
: DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
: DAG(DAG), TLI(TLI), VPOp() {
assert(Root->isVPOpcode());
if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
RootMaskOp = Root->getOperand(*RootMaskPos);
VPOp.setMask(Root->getOperand(*RootMaskPos));
else if (Root->getOpcode() == ISD::VP_SELECT)
RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
Root->getOperand(0).getValueType());
VPOp.setMask(DAG.getAllOnesConstant(SDLoc(Root),
Root->getOperand(0).getValueType()));

if (auto RootVLenPos =
ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
RootVectorLenOp = Root->getOperand(*RootVLenPos);
VPOp.setVL(Root->getOperand(*RootVLenPos));
}

/// whether \p OpVal is a node that is functionally compatible with the
Expand All @@ -952,14 +956,14 @@ class VPMatchContext {
unsigned VPOpcode = OpVal->getOpcode();
if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
SDValue MaskOp = OpVal.getOperand(*MaskPos);
if (RootMaskOp != MaskOp &&
if (!VPOp.isMaskEqualsTo(MaskOp) &&
!ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
return false;
}

// Make sure the EVL of OpVal is same as Root's.
if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
if (!VPOp.isVLEqualsTo(OpVal.getOperand(*VLenPos)))
return false;
return true;
}
Expand All @@ -972,55 +976,58 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
return DAG.getNode(VPOpcode, DL, VT,
{Operand, RootMaskOp, RootVectorLenOp});
return DAG.getNode(VPOpcode, DL, VT, {Operand}, VPOp);
}

SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2) {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
return DAG.getNode(VPOpcode, DL, VT,
{N1, N2, RootMaskOp, RootVectorLenOp});
return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, VPOp);
}

SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3) {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
return DAG.getNode(VPOpcode, DL, VT,
{N1, N2, N3, RootMaskOp, RootVectorLenOp});
return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, VPOp);
}

SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
SDNodeFlags Flags) {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
Flags);
return DAG.getNode(VPOpcode, DL, VT, {Operand}, Flags, VPOp);
}

SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDNodeFlags Flags) {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
Flags);
return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, Flags, VPOp);
}

SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDNodeFlags Flags) {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
return DAG.getNode(VPOpcode, DL, VT,
{N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, Flags, VPOp);
}

// Same as SelectionDAG::FUNCT_NAME(Args, VPOp).
#define GET_SELECTION_DAG_VP_FUNCT(FUNCT_NAME) \
template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
return DAG.FUNCT_NAME(std::forward<ArgT>(Args)..., VPOp); \
}

GET_SELECTION_DAG_VP_FUNCT(getNegative)
#undef GET_SELECTION_DAG_VP_FUNCT

bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
Expand Down
27 changes: 25 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1548,8 +1548,10 @@ SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
return getZeroExtendInReg(Op, DL, VT);
}

SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT,
VPMaskAndVL VPOp) {
auto Opcode = VPOp.empty() ? ISD::SUB : ISD::VP_SUB;
return getNode(Opcode, DL, VT, {getConstant(0, DL, VT), Val}, VPOp);
}

/// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
Expand Down Expand Up @@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(Opcode, DL, VT, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic seems different between this function and getNegative. getNegative will use the VP version SUB automatically if we have a non-null VPOp. But these 2 new getNode ask the user to chose right opcode even giving a non-null VPOp. It's a little bit inconsistent.

ArrayRef<SDValue> Ops, VPMaskAndVL VPOp) {
SmallVector<SDValue, 8> OpsVec(Ops);
if (!VPOp.empty()) {
OpsVec.push_back(VPOp.getMask());
OpsVec.push_back(VPOp.getVL());
}
return getNode(Opcode, DL, VT, OpsVec);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
unsigned NumOps = Ops.size();
Expand Down Expand Up @@ -9820,6 +9832,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return V;
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
VPMaskAndVL VPOp) {
SmallVector<SDValue, 8> OpsVec(Ops);
if (!VPOp.empty()) {
OpsVec.push_back(VPOp.getMask());
OpsVec.push_back(VPOp.getVL());
}
return getNode(Opcode, DL, VT, OpsVec, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
return getNode(Opcode, DL, getVTList(ResultTys), Ops);
Expand Down