Skip to content

Commit 2887f14

Browse files
committed
[ISel] Port AArch64 SABD and UABD to DAGCombine
This ports the AArch64 SABD and UABD over to DAG Combine, where they can be used by more backends (notably MVE in a follow-up patch). The matching code has changed very little, just to handle legal operations and types differently. It selects from (ABS (SUB (EXTEND a), (EXTEND b))), producing an abds/abdu which is zexted to the original type. Differential Revision: https://reviews.llvm.org/D91937
1 parent 8c2d462 commit 2887f14

File tree

8 files changed

+69
-63
lines changed

8 files changed

+69
-63
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,13 @@ enum NodeType {
611611
MULHU,
612612
MULHS,
613613

614+
// ABDS/ABDU - Absolute difference - Return the absolute difference between
615+
// two numbers interpreted as signed/unsigned.
616+
// i.e. trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
617+
// or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
618+
ABDS,
619+
ABDU,
620+
614621
/// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
615622
/// integers.
616623
SMIN,

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
369369
[SDNPCommutative, SDNPAssociative]>;
370370
def mulhs : SDNode<"ISD::MULHS" , SDTIntBinOp, [SDNPCommutative]>;
371371
def mulhu : SDNode<"ISD::MULHU" , SDTIntBinOp, [SDNPCommutative]>;
372+
def abds : SDNode<"ISD::ABDS" , SDTIntBinOp, [SDNPCommutative]>;
373+
def abdu : SDNode<"ISD::ABDU" , SDTIntBinOp, [SDNPCommutative]>;
372374
def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
373375
def umullohi : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
374376
def sdiv : SDNode<"ISD::SDIV" , SDTIntBinOp>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
90719071
return SDValue();
90729072
}
90739073

9074+
// Given an ABS node, detect the following pattern:
9075+
// (ABS (SUB (EXTEND a), (EXTEND b))).
9076+
// When matched, builds an ISD::ABDS (both operands sign-extended) or
// ISD::ABDU (both operands zero-extended) node on the un-extended operands
// and zero-extends it to the result type of the ABS node. Returns an empty
// SDValue when the pattern does not match or the target does not report the
// ABD opcode as legal or custom for the narrow type.
9077+
static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
9078+
const TargetLowering &TLI) {
9079+
SDValue AbsOp1 = N->getOperand(0);
9080+
SDValue Op0, Op1;
9081+
9082+
// The operand of the ABS must be a SUB.
if (AbsOp1.getOpcode() != ISD::SUB)
9083+
return SDValue();
9084+
9085+
Op0 = AbsOp1.getOperand(0);
9086+
Op1 = AbsOp1.getOperand(1);
9087+
9088+
unsigned Opc0 = Op0.getOpcode();
9089+
// Check that both operands of the sub use the same kind of extension, and
// that it is a zero- or sign-extension.
9090+
if (Opc0 != Op1.getOpcode() ||
9091+
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
9092+
return SDValue();
9093+
9094+
// VT1/VT2 are the types of the values before extension.
EVT VT1 = Op0.getOperand(0).getValueType();
9095+
EVT VT2 = Op1.getOperand(0).getValueType();
9096+
// Check that the un-extended operands have the same type and that the
// target reports the chosen ABD opcode as legal or custom for that type.
9097+
unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
9098+
if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
9099+
return SDValue();
9100+
9101+
// Look through the extends so the ABD node is built on the narrow values.
Op0 = Op0.getOperand(0);
9102+
Op1 = Op1.getOperand(0);
9103+
SDValue ABD =
9104+
DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
9105+
// The narrow result is always zero-extended (for both ABDS and ABDU) back to
// the type the ABS node produced.
return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
9106+
}
9107+
90749108
SDValue DAGCombiner::visitABS(SDNode *N) {
90759109
SDValue N0 = N->getOperand(0);
90769110
EVT VT = N->getValueType(0);
@@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
90849118
// fold (abs x) -> x iff not-negative
90859119
if (DAG.SignBitIsZero(N0))
90869120
return N0;
9121+
9122+
if (SDValue ABD = combineABSToABD(N, DAG, TLI))
9123+
return ABD;
9124+
90879125
return SDValue();
90889126
}
90899127

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
231231
case ISD::MUL: return "mul";
232232
case ISD::MULHU: return "mulhu";
233233
case ISD::MULHS: return "mulhs";
234+
case ISD::ABDS: return "abds";
235+
case ISD::ABDU: return "abdu";
234236
case ISD::SDIV: return "sdiv";
235237
case ISD::UDIV: return "udiv";
236238
case ISD::SREM: return "srem";

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
813813
setOperationAction(ISD::SUBC, VT, Expand);
814814
setOperationAction(ISD::SUBE, VT, Expand);
815815

816+
// Absolute difference
817+
setOperationAction(ISD::ABDS, VT, Expand);
818+
setOperationAction(ISD::ABDU, VT, Expand);
819+
816820
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
817821
setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
818822
setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 14 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
10501050
setOperationAction(ISD::USUBSAT, VT, Legal);
10511051
}
10521052

1053+
for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
1054+
MVT::v4i32}) {
1055+
setOperationAction(ISD::ABDS, VT, Legal);
1056+
setOperationAction(ISD::ABDU, VT, Legal);
1057+
}
1058+
10531059
// Vector reductions
10541060
for (MVT VT : { MVT::v4f16, MVT::v2f32,
10551061
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
21162122
MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
21172123
MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
21182124
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
2119-
MAKE_CASE(AArch64ISD::UABD)
2120-
MAKE_CASE(AArch64ISD::SABD)
21212125
MAKE_CASE(AArch64ISD::UADDLP)
21222126
MAKE_CASE(AArch64ISD::CALL_RVMARKER)
21232127
}
@@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
40824086
}
40834087
case Intrinsic::aarch64_neon_sabd:
40844088
case Intrinsic::aarch64_neon_uabd: {
4085-
unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
4086-
: AArch64ISD::SABD;
4089+
unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
4090+
: ISD::ABDS;
40874091
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
40884092
Op.getOperand(2));
40894093
}
@@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1209912103
SDValue UABDHigh8Op1 =
1210012104
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
1210112105
DAG.getConstant(8, DL, MVT::i64));
12102-
SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
12103-
DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
12106+
SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
12107+
UABDHigh8Op0, UABDHigh8Op1);
1210412108
SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
1210512109

1210612110
// Second, create the node pattern of UABAL.
@@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1211012114
SDValue UABDLo8Op1 =
1211112115
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
1211212116
DAG.getConstant(0, DL, MVT::i64));
12113-
SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
12114-
DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
12117+
SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
12118+
UABDLo8Op0, UABDLo8Op1);
1211512119
SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
1211612120
SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
1211712121

@@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1217012174
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
1217112175
}
1217212176

12173-
// Given a ABS node, detect the following pattern:
12174-
// (ABS (SUB (EXTEND a), (EXTEND b))).
12175-
// Generates UABD/SABD instruction.
12176-
static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
12177-
TargetLowering::DAGCombinerInfo &DCI,
12178-
const AArch64Subtarget *Subtarget) {
12179-
SDValue AbsOp1 = N->getOperand(0);
12180-
SDValue Op0, Op1;
12181-
12182-
if (AbsOp1.getOpcode() != ISD::SUB)
12183-
return SDValue();
12184-
12185-
Op0 = AbsOp1.getOperand(0);
12186-
Op1 = AbsOp1.getOperand(1);
12187-
12188-
unsigned Opc0 = Op0.getOpcode();
12189-
// Check if the operands of the sub are (zero|sign)-extended.
12190-
if (Opc0 != Op1.getOpcode() ||
12191-
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
12192-
return SDValue();
12193-
12194-
EVT VectorT1 = Op0.getOperand(0).getValueType();
12195-
EVT VectorT2 = Op1.getOperand(0).getValueType();
12196-
// Check if vectors are of same type and valid size.
12197-
uint64_t Size = VectorT1.getFixedSizeInBits();
12198-
if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
12199-
return SDValue();
12200-
12201-
// Check if vector element types are valid.
12202-
EVT VT1 = VectorT1.getVectorElementType();
12203-
if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
12204-
return SDValue();
12205-
12206-
Op0 = Op0.getOperand(0);
12207-
Op1 = Op1.getOperand(0);
12208-
unsigned ABDOpcode =
12209-
(Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
12210-
SDValue ABD =
12211-
DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
12212-
return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
12213-
}
12214-
1221512177
static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
1221612178
TargetLowering::DAGCombinerInfo &DCI,
1221712179
const AArch64Subtarget *Subtarget) {
@@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
1437714339
// helps the backend to decide that an sabdl2 would be useful, saving a real
1437814340
// extract_high operation.
1437914341
if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
14380-
(N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
14381-
N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
14342+
(N->getOperand(0).getOpcode() == ISD::ABDU ||
14343+
N->getOperand(0).getOpcode() == ISD::ABDS)) {
1438214344
SDNode *ABDNode = N->getOperand(0).getNode();
1438314345
SDValue NewABD =
1438414346
tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
1634416306
default:
1634516307
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
1634616308
break;
16347-
case ISD::ABS:
16348-
return performABSCombine(N, DAG, DCI, Subtarget);
1634916309
case ISD::ADD:
1635016310
case ISD::SUB:
1635116311
return performAddSubCombine(N, DCI, DAG);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,6 @@ enum NodeType : unsigned {
236236
SRHADD,
237237
URHADD,
238238

239-
// Absolute difference
240-
UABD,
241-
SABD,
242-
243239
// Unsigned Add Long Pairwise
244240
UADDLP,
245241

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -579,14 +579,11 @@ def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
579579
def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
580580
def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
581581

582-
def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
583-
def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
584-
585582
def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
586-
[(AArch64uabd_n node:$lhs, node:$rhs),
583+
[(abdu node:$lhs, node:$rhs),
587584
(int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
588585
def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs),
589-
[(AArch64sabd_n node:$lhs, node:$rhs),
586+
[(abds node:$lhs, node:$rhs),
590587
(int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
591588

592589
def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;

0 commit comments

Comments
 (0)