Skip to content

Commit b1234dd

Browse files
authored
[DAG] Add legalization handling for ABDS/ABDU (#92576)
Always match ABD patterns pre-legalization, and use TargetLowering::expandABD to expand them again during legalization. For illegal scalar types, abdu is expanded using the sign-extended usubo overflow flag: abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), usub_overflow(lhs, rhs)), usub_overflow(lhs, rhs)). Alive2 proof: https://alive2.llvm.org/ce/z/dVdMyv
1 parent bb59f04 commit b1234dd

28 files changed

+3198
-4078
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4089,13 +4089,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
40894089
}
40904090

40914091
// smax(a,b) - smin(a,b) --> abds(a,b)
4092-
if (hasOperation(ISD::ABDS, VT) &&
4092+
if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
40934093
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
40944094
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
40954095
return DAG.getNode(ISD::ABDS, DL, VT, A, B);
40964096

40974097
// umax(a,b) - umin(a,b) --> abdu(a,b)
4098-
if (hasOperation(ISD::ABDU, VT) &&
4098+
if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
40994099
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
41004100
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
41014101
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
@@ -10922,6 +10922,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1092210922
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
1092310923
Opc0 != ISD::SIGN_EXTEND_INREG)) {
1092410924
// fold (abs (sub nsw x, y)) -> abds(x, y)
10925+
// Don't fold this for unsupported types as we lose the NSW handling.
1092510926
if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
1092610927
TLI.preferABDSToABSWithNSW(VT)) {
1092710928
SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
@@ -10944,7 +10945,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1094410945
// fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
1094510946
EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
1094610947
if ((VT0 == MaxVT || Op0->hasOneUse()) &&
10947-
(VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
10948+
(VT1 == MaxVT || Op1->hasOneUse()) &&
10949+
(!LegalOperations || hasOperation(ABDOpcode, MaxVT))) {
1094810950
SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
1094910951
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
1095010952
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
@@ -10954,7 +10956,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
1095410956

1095510957
// fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
1095610958
// fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
10957-
if (hasOperation(ABDOpcode, VT)) {
10959+
if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
1095810960
SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
1095910961
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
1096010962
}
@@ -12376,7 +12378,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1237612378
N1.getOperand(1) == N2.getOperand(0)) {
1237712379
bool IsSigned = isSignedIntSetCC(CC);
1237812380
unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
12379-
if (hasOperation(ABDOpc, VT)) {
12381+
if (!LegalOperations || hasOperation(ABDOpc, VT)) {
1238012382
switch (CC) {
1238112383
case ISD::SETGT:
1238212384
case ISD::SETGE:

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,21 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
192192
case ISD::VP_SUB:
193193
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
194194

195+
case ISD::ABDS:
195196
case ISD::AVGCEILS:
196197
case ISD::AVGFLOORS:
198+
197199
case ISD::VP_SMIN:
198200
case ISD::VP_SMAX:
199201
case ISD::SDIV:
200202
case ISD::SREM:
201203
case ISD::VP_SDIV:
202204
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
203205

206+
case ISD::ABDU:
204207
case ISD::AVGCEILU:
205208
case ISD::AVGFLOORU:
209+
206210
case ISD::VP_UMIN:
207211
case ISD::VP_UMAX:
208212
case ISD::UDIV:
@@ -2791,6 +2795,8 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
27912795
case ISD::PARITY: ExpandIntRes_PARITY(N, Lo, Hi); break;
27922796
case ISD::Constant: ExpandIntRes_Constant(N, Lo, Hi); break;
27932797
case ISD::ABS: ExpandIntRes_ABS(N, Lo, Hi); break;
2798+
case ISD::ABDS:
2799+
case ISD::ABDU: ExpandIntRes_ABD(N, Lo, Hi); break;
27942800
case ISD::CTLZ_ZERO_UNDEF:
27952801
case ISD::CTLZ: ExpandIntRes_CTLZ(N, Lo, Hi); break;
27962802
case ISD::CTPOP: ExpandIntRes_CTPOP(N, Lo, Hi); break;
@@ -3850,6 +3856,11 @@ void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
38503856
Hi = DAG.getConstant(0, dl, NVT);
38513857
}
38523858

3859+
void DAGTypeLegalizer::ExpandIntRes_ABD(SDNode *N, SDValue &Lo, SDValue &Hi) {
3860+
SDValue Result = TLI.expandABD(N, DAG);
3861+
SplitInteger(Result, Lo, Hi);
3862+
}
3863+
38533864
void DAGTypeLegalizer::ExpandIntRes_CTPOP(SDNode *N,
38543865
SDValue &Lo, SDValue &Hi) {
38553866
SDLoc dl(N);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
448448
void ExpandIntRes_AssertZext (SDNode *N, SDValue &Lo, SDValue &Hi);
449449
void ExpandIntRes_Constant (SDNode *N, SDValue &Lo, SDValue &Hi);
450450
void ExpandIntRes_ABS (SDNode *N, SDValue &Lo, SDValue &Hi);
451+
void ExpandIntRes_ABD (SDNode *N, SDValue &Lo, SDValue &Hi);
451452
void ExpandIntRes_CTLZ (SDNode *N, SDValue &Lo, SDValue &Hi);
452453
void ExpandIntRes_CTPOP (SDNode *N, SDValue &Lo, SDValue &Hi);
453454
void ExpandIntRes_CTTZ (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
147147
case ISD::FMINIMUM:
148148
case ISD::FMAXIMUM:
149149
case ISD::FLDEXP:
150+
case ISD::ABDS:
151+
case ISD::ABDU:
150152
case ISD::SMIN:
151153
case ISD::SMAX:
152154
case ISD::UMIN:
@@ -1233,6 +1235,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
12331235
case ISD::MUL: case ISD::VP_MUL:
12341236
case ISD::MULHS:
12351237
case ISD::MULHU:
1238+
case ISD::ABDS:
1239+
case ISD::ABDU:
12361240
case ISD::AVGCEILS:
12371241
case ISD::AVGCEILU:
12381242
case ISD::AVGFLOORS:
@@ -4368,6 +4372,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
43684372
case ISD::MUL: case ISD::VP_MUL:
43694373
case ISD::MULHS:
43704374
case ISD::MULHU:
4375+
case ISD::ABDS:
4376+
case ISD::ABDU:
43714377
case ISD::OR: case ISD::VP_OR:
43724378
case ISD::SUB: case ISD::VP_SUB:
43734379
case ISD::XOR: case ISD::VP_XOR:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7024,6 +7024,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
70247024
assert(VT.isInteger() && "This operator does not apply to FP types!");
70257025
assert(N1.getValueType() == N2.getValueType() &&
70267026
N1.getValueType() == VT && "Binary operator types must match!");
7027+
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
7028+
return getNode(ISD::XOR, DL, VT, N1, N2);
70277029
break;
70287030
case ISD::SMIN:
70297031
case ISD::UMAX:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9311,6 +9311,21 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
93119311
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
93129312
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));
93139313

9314+
// If the subtract doesn't overflow then just use abs(sub())
9315+
// NOTE: don't use frozen operands for value tracking.
9316+
bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
9317+
DAG.SignBitIsZero(N->getOperand(0));
9318+
9319+
if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
9320+
N->getOperand(1)))
9321+
return DAG.getNode(ISD::ABS, dl, VT,
9322+
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));
9323+
9324+
if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
9325+
N->getOperand(0)))
9326+
return DAG.getNode(ISD::ABS, dl, VT,
9327+
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
9328+
93149329
EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
93159330
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
93169331
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
@@ -9324,6 +9339,23 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
93249339
return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
93259340
}
93269341

9342+
// Similar to the branchless expansion, use the (sign-extended) usubo overflow
9343+
// flag if the (scalar) type is illegal as this is more likely to legalize
9344+
// cleanly:
9345+
// abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), uof(lhs, rhs)), uof(lhs, rhs))
9346+
if (!IsSigned && VT.isScalarInteger() && !isTypeLegal(VT)) {
9347+
SDValue USubO =
9348+
DAG.getNode(ISD::USUBO, dl, DAG.getVTList(VT, MVT::i1), {LHS, RHS});
9349+
SDValue Cmp = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, USubO.getValue(1));
9350+
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, USubO.getValue(0), Cmp);
9351+
return DAG.getNode(ISD::SUB, dl, VT, Xor, Cmp);
9352+
}
9353+
9354+
// FIXME: Should really try to split the vector in case it's legal on a
9355+
// subvector.
9356+
if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
9357+
return DAG.UnrollVectorOp(N);
9358+
93279359
// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
93289360
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
93299361
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),

0 commit comments

Comments
 (0)