Skip to content

Commit 04e809a

Browse files
committed
[DAG] Add TargetLowering::expandABD and convert X86 lowering to use it
Scalar widening cases are still custom lowered in the X86 backend - we still need to add promotion/legalization support to handle these
1 parent 44e7b8a commit 04e809a

File tree

5 files changed

+54
-26
lines changed

5 files changed

+54
-26
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5033,6 +5033,11 @@ class TargetLowering : public TargetLoweringBase {
50335033
SDValue expandABS(SDNode *N, SelectionDAG &DAG,
50345034
bool IsNegative = false) const;
50355035

5036+
/// Expand ABDS/ABDU nodes. Expands vector/scalar ABDS/ABDU nodes.
5037+
/// \param N Node to expand
5038+
/// \returns The expansion result or SDValue() if it fails.
5039+
SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
5040+
50365041
/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
50375042
/// scalar types. Returns SDValue() if expand fails.
50385043
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2696,6 +2696,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
26962696
if ((Tmp1 = TLI.expandABS(Node, DAG)))
26972697
Results.push_back(Tmp1);
26982698
break;
2699+
case ISD::ABDS:
2700+
case ISD::ABDU:
2701+
if ((Tmp1 = TLI.expandABD(Node, DAG)))
2702+
Results.push_back(Tmp1);
2703+
break;
26992704
case ISD::CTPOP:
27002705
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
27012706
Results.push_back(Tmp1);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
795795
return;
796796
}
797797
break;
798+
case ISD::ABDS:
799+
case ISD::ABDU:
800+
if (SDValue Expanded = TLI.expandABD(Node, DAG)) {
801+
Results.push_back(Expanded);
802+
return;
803+
}
804+
break;
798805
case ISD::BITREVERSE:
799806
ExpandBITREVERSE(Node, Results);
800807
return;

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8627,6 +8627,38 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
86278627
return DAG.getNode(ISD::SUB, dl, VT, Shift, Xor);
86288628
}
86298629

8630+
SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
8631+
SDLoc dl(N);
8632+
EVT VT = N->getValueType(0);
8633+
SDValue LHS = DAG.getFreeze(N->getOperand(0));
8634+
SDValue RHS = DAG.getFreeze(N->getOperand(1));
8635+
bool IsSigned = N->getOpcode() == ISD::ABDS;
8636+
8637+
// abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
8638+
// abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
8639+
unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
8640+
unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
8641+
if (isOperationLegal(MaxOpc, VT) && isOperationLegal(MinOpc, VT)) {
8642+
SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
8643+
SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
8644+
return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
8645+
}
8646+
8647+
// abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
8648+
if (!IsSigned && isOperationLegal(ISD::USUBSAT, VT))
8649+
return DAG.getNode(ISD::OR, dl, VT,
8650+
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
8651+
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
8652+
8653+
// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
8654+
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
8655+
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
8656+
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
8657+
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
8658+
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
8659+
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
8660+
}
8661+
86308662
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
86318663
SDLoc dl(N);
86328664
EVT VT = N->getValueType(0);

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30375,36 +30375,20 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
3037530375
if ((VT == MVT::v32i16 || VT == MVT::v64i8) && !Subtarget.useBWIRegs())
3037630376
return splitVectorIntBinary(Op, DAG);
3037730377

30378-
// TODO: Add TargetLowering expandABD() support.
3037930378
SDLoc dl(Op);
3038030379
bool IsSigned = Op.getOpcode() == ISD::ABDS;
30381-
SDValue LHS = DAG.getFreeze(Op.getOperand(0));
30382-
SDValue RHS = DAG.getFreeze(Op.getOperand(1));
3038330380
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
3038430381

30385-
// abds(lhs, rhs) -> sub(smax(lhs,rhs), smin(lhs,rhs))
30386-
// abdu(lhs, rhs) -> sub(umax(lhs,rhs), umin(lhs,rhs))
30387-
unsigned MaxOpc = IsSigned ? ISD::SMAX : ISD::UMAX;
30388-
unsigned MinOpc = IsSigned ? ISD::SMIN : ISD::UMIN;
30389-
if (TLI.isOperationLegal(MaxOpc, VT) && TLI.isOperationLegal(MinOpc, VT)) {
30390-
SDValue Max = DAG.getNode(MaxOpc, dl, VT, LHS, RHS);
30391-
SDValue Min = DAG.getNode(MinOpc, dl, VT, LHS, RHS);
30392-
return DAG.getNode(ISD::SUB, dl, VT, Max, Min);
30393-
}
30394-
30395-
// abdu(lhs, rhs) -> or(usubsat(lhs,rhs), usubsat(rhs,lhs))
30396-
if (!IsSigned && TLI.isOperationLegal(ISD::USUBSAT, VT))
30397-
return DAG.getNode(ISD::OR, dl, VT,
30398-
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
30399-
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
30400-
30382+
// TODO: Move to TargetLowering expandABD() once we have ABD promotion.
3040130383
if (VT.isScalarInteger()) {
3040230384
unsigned WideBits = std::max<unsigned>(2 * VT.getScalarSizeInBits(), 32u);
3040330385
MVT WideVT = MVT::getIntegerVT(WideBits);
3040430386
if (TLI.isTypeLegal(WideVT)) {
3040530387
// abds(lhs, rhs) -> trunc(abs(sub(sext(lhs), sext(rhs))))
3040630388
// abdu(lhs, rhs) -> trunc(abs(sub(zext(lhs), zext(rhs))))
3040730389
unsigned ExtOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
30390+
SDValue LHS = DAG.getFreeze(Op.getOperand(0));
30391+
SDValue RHS = DAG.getFreeze(Op.getOperand(1));
3040830392
LHS = DAG.getNode(ExtOpc, dl, WideVT, LHS);
3040930393
RHS = DAG.getNode(ExtOpc, dl, WideVT, RHS);
3041030394
SDValue Diff = DAG.getNode(ISD::SUB, dl, WideVT, LHS, RHS);
@@ -30413,13 +30397,8 @@ static SDValue LowerABD(SDValue Op, const X86Subtarget &Subtarget,
3041330397
}
3041430398
}
3041530399

30416-
// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
30417-
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
30418-
EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
30419-
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
30420-
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
30421-
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
30422-
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
30400+
// Default to expand.
30401+
return SDValue();
3042330402
}
3042430403

3042530404
static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,

0 commit comments

Comments
 (0)