Commit 4072e36

[ISel] Port AArch64 HADD and RHADD to ISel
This ports the AArch64 combines for HADD and RHADD over to DAG combine, so that they can be used by more architectures (notably MVE in a follow-up patch). They are renamed to AVGFLOOR and AVGCEIL in the process, to avoid confusion with instructions such as the X86 hadd. The code was also rewritten slightly to remove the AArch64 idiosyncrasies.

The general pattern for an AVGFLOORS is:
  %xe = sext i8 %x to i32
  %ye = sext i8 %y to i32
  %a = add i32 %xe, %ye
  %r = lshr i32 %a, 1
  %t = trunc i32 %r to i8

An AVGFLOORU is equivalent with zext. Because of the truncate, lshr is equivalent to ashr here, as the top bits are not demanded. An AVGCEIL additionally rounds, so it includes an extra add of 1.

Differential Revision: https://reviews.llvm.org/D106237
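For reference, the semantics of the four new nodes can be modelled directly in C++ by widening exactly as the pattern above does. This is an illustrative sketch, not code from this patch, and the helper names are made up:

  #include <cstdint>

  // Scalar reference semantics for the new ISD nodes on i8 elements.
  // The signed variants rely on an arithmetic right shift of int32_t
  // (true on common targets; guaranteed from C++20).
  int8_t  avgfloors(int8_t x, int8_t y)   { return (int8_t)(((int32_t)x + (int32_t)y) >> 1); }
  uint8_t avgflooru(uint8_t x, uint8_t y) { return (uint8_t)(((uint32_t)x + (uint32_t)y) >> 1); }
  int8_t  avgceils(int8_t x, int8_t y)    { return (int8_t)(((int32_t)x + (int32_t)y + 1) >> 1); }
  uint8_t avgceilu(uint8_t x, uint8_t y)  { return (uint8_t)(((uint32_t)x + (uint32_t)y + 1) >> 1); }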
1 parent d828281 commit 4072e36

10 files changed, +139 −122 lines

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 11 additions & 0 deletions
@@ -617,6 +617,17 @@ enum NodeType {
   MULHU,
   MULHS,
 
+  /// AVGFLOORS/AVGFLOORU - Averaging add - Add two integers using an integer of
+  /// type i[N+1], halving the result by shifting it one bit right.
+  /// shr(add(ext(X), ext(Y)), 1)
+  AVGFLOORS,
+  AVGFLOORU,
+  /// AVGCEILS/AVGCEILU - Rounding averaging add - Add two integers using an
+  /// integer of type i[N+2], add 1 and halve the result by shifting it one bit
+  /// right. shr(add(ext(X), ext(Y), 1), 1)
+  AVGCEILS,
+  AVGCEILU,
+
   // ABDS/ABDU - Absolute difference - Return the absolute difference between
   // two numbers interpreted as signed/unsigned.
   // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
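As a side note on the semantics documented above: the same floor and ceiling averages can be computed without any wider intermediate type using standard bit identities. This is shown only to illustrate the operations, and is not claimed to be how LLVM expands the new nodes:

  #include <cstdint>

  // floor((x + y) / 2) and ceil((x + y) / 2) for unsigned values with no
  // widening: x + y == 2*(x & y) + (x ^ y) == 2*(x | y) - (x ^ y).
  uint8_t avgflooru_nowiden(uint8_t x, uint8_t y) { return (uint8_t)((x & y) + ((x ^ y) >> 1)); }
  uint8_t avgceilu_nowiden(uint8_t x, uint8_t y)  { return (uint8_t)((x | y) - ((x ^ y) >> 1)); }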

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 0 deletions
@@ -2515,6 +2515,10 @@ class TargetLoweringBase {
   case ISD::FMAXNUM_IEEE:
   case ISD::FMINIMUM:
   case ISD::FMAXIMUM:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
     return true;
   default: return false;
   }

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 4 additions & 0 deletions
@@ -365,6 +365,10 @@ def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def mulhs     : SDNode<"ISD::MULHS"     , SDTIntBinOp, [SDNPCommutative]>;
 def mulhu     : SDNode<"ISD::MULHU"     , SDTIntBinOp, [SDNPCommutative]>;
+def avgfloors : SDNode<"ISD::AVGFLOORS" , SDTIntBinOp, [SDNPCommutative]>;
+def avgflooru : SDNode<"ISD::AVGFLOORU" , SDTIntBinOp, [SDNPCommutative]>;
+def avgceils  : SDNode<"ISD::AVGCEILS"  , SDTIntBinOp, [SDNPCommutative]>;
+def avgceilu  : SDNode<"ISD::AVGCEILU"  , SDTIntBinOp, [SDNPCommutative]>;
 def abds      : SDNode<"ISD::ABDS"      , SDTIntBinOp, [SDNPCommutative]>;
 def abdu      : SDNode<"ISD::ABDU"      , SDTIntBinOp, [SDNPCommutative]>;
 def smullohi  : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 83 additions & 0 deletions
@@ -12688,6 +12688,87 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
   return SDValue();
 }
 
+// Attempt to form one of the avg patterns from:
+//  truncate(shr(add(zext(OpB), zext(OpA)), 1))
+// Creating avgflooru/avgfloors/avgceilu/avgceils, with the ceiling having an
+// extra rounding add:
+//  truncate(shr(add(zext(OpB), zext(OpA), 1), 1))
+// This starts at a truncate, meaning the shift will always be srl, as the top
+// bits are not demanded.
+static SDValue performAvgCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::TRUNCATE && "TRUNCATE node expected");
+  EVT VT = N->getValueType(0);
+
+  SDValue Shift = N->getOperand(0);
+  if (Shift.getOpcode() != ISD::SRL)
+    return SDValue();
+
+  // Is the right shift using an immediate value of 1?
+  ConstantSDNode *N1C = isConstOrConstSplat(Shift.getOperand(1));
+  if (!N1C || !N1C->isOne())
+    return SDValue();
+
+  // We are looking for an avgfloor
+  // add(ext, ext)
+  // or one of these as an avgceil
+  // add(add(ext, ext), 1)
+  // add(add(ext, 1), ext)
+  // add(ext, add(ext, 1))
+  SDValue Add = Shift.getOperand(0);
+  if (Add.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue ExtendOpA = Add.getOperand(0);
+  SDValue ExtendOpB = Add.getOperand(1);
+  auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3) {
+    ConstantSDNode *ConstOp;
+    if ((ConstOp = isConstOrConstSplat(Op1)) && ConstOp->isOne()) {
+      ExtendOpA = Op2;
+      ExtendOpB = Op3;
+      return true;
+    }
+    if ((ConstOp = isConstOrConstSplat(Op2)) && ConstOp->isOne()) {
+      ExtendOpA = Op1;
+      ExtendOpB = Op3;
+      return true;
+    }
+    if ((ConstOp = isConstOrConstSplat(Op3)) && ConstOp->isOne()) {
+      ExtendOpA = Op1;
+      ExtendOpB = Op2;
+      return true;
+    }
+    return false;
+  };
+  bool IsCeil = (ExtendOpA.getOpcode() == ISD::ADD &&
+                 MatchOperands(ExtendOpA.getOperand(0), ExtendOpA.getOperand(1),
+                               ExtendOpB)) ||
+                (ExtendOpB.getOpcode() == ISD::ADD &&
+                 MatchOperands(ExtendOpB.getOperand(0), ExtendOpB.getOperand(1),
+                               ExtendOpA));
+
+  unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
+  unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
+  if (!(ExtendOpAOpc == ExtendOpBOpc &&
+        (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
+    return SDValue();
+
+  // Is the result of the right shift being truncated to the same value type as
+  // the original operands, OpA and OpB?
+  SDValue OpA = ExtendOpA.getOperand(0);
+  SDValue OpB = ExtendOpB.getOperand(0);
+  EVT OpAVT = OpA.getValueType();
+  if (VT != OpAVT || OpAVT != OpB.getValueType())
+    return SDValue();
+
+  bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
+  unsigned AVGOpc = IsSignExtend ? (IsCeil ? ISD::AVGCEILS : ISD::AVGFLOORS)
+                                 : (IsCeil ? ISD::AVGCEILU : ISD::AVGFLOORU);
+  if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(AVGOpc, VT))
+    return SDValue();
+
+  return DAG.getNode(AVGOpc, SDLoc(N), VT, OpA, OpB);
+}
+
 SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -12974,6 +13055,8 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
 
   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
     return NewVSel;
+  if (SDValue M = performAvgCombine(N, DAG))
+    return M;
 
   // Narrow a suitable binary operation with a non-opaque constant operand by
   // moving it ahead of the truncate. This is limited to pre-legalization
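For context, source of roughly this shape is what ends up as the trunc(srl(add(zext, zext), 1)) DAG that performAvgCombine matches once the loop is vectorized. This is a hypothetical C++ sketch, not taken from the patch's tests, and whether the avg nodes are actually formed depends on the target marking them legal:

  #include <cstddef>
  #include <cstdint>

  // Per-element byte averages: each element is zero extended, added, shifted
  // right by one and truncated back to i8. The "+ 1" form maps to the ceiling
  // (rounding) variant.
  void average_u8(uint8_t *dst, const uint8_t *a, const uint8_t *b, size_t n) {
    for (size_t i = 0; i != n; ++i)
      dst[i] = (uint8_t)(((uint32_t)a[i] + (uint32_t)b[i]) >> 1);     // avgflooru
  }

  void average_round_u8(uint8_t *dst, const uint8_t *a, const uint8_t *b, size_t n) {
    for (size_t i = 0; i != n; ++i)
      dst[i] = (uint8_t)(((uint32_t)a[i] + (uint32_t)b[i] + 1) >> 1); // avgceilu
  }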

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 4 additions & 0 deletions
@@ -3287,6 +3287,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::USHLSAT:
   case ISD::ROTL:
   case ISD::ROTR:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
     // Vector-predicated binary op widening. Note that -- unlike the
     // unpredicated versions -- we don't have to worry about trapping on
     // operations like UDIV, FADD, etc., as we pass on the original vector

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 4 additions & 0 deletions
@@ -231,6 +231,10 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::MUL:        return "mul";
   case ISD::MULHU:      return "mulhu";
   case ISD::MULHS:      return "mulhs";
+  case ISD::AVGFLOORU:  return "avgflooru";
+  case ISD::AVGFLOORS:  return "avgfloors";
+  case ISD::AVGCEILU:   return "avgceilu";
+  case ISD::AVGCEILS:   return "avgceils";
   case ISD::ABDS:       return "abds";
   case ISD::ABDU:       return "abdu";
   case ISD::SDIV:       return "sdiv";

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 6 additions & 0 deletions
@@ -817,6 +817,12 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SUBC, VT, Expand);
     setOperationAction(ISD::SUBE, VT, Expand);
 
+    // Halving adds
+    setOperationAction(ISD::AVGFLOORS, VT, Expand);
+    setOperationAction(ISD::AVGFLOORU, VT, Expand);
+    setOperationAction(ISD::AVGCEILS, VT, Expand);
+    setOperationAction(ISD::AVGCEILU, VT, Expand);
+
     // Absolute difference
     setOperationAction(ISD::ABDS, VT, Expand);
     setOperationAction(ISD::ABDU, VT, Expand);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 19 additions & 105 deletions
@@ -870,7 +870,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::SIGN_EXTEND);
   setTargetDAGCombine(ISD::VECTOR_SPLICE);
   setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
-  setTargetDAGCombine(ISD::TRUNCATE);
   setTargetDAGCombine(ISD::CONCAT_VECTORS);
   setTargetDAGCombine(ISD::INSERT_SUBVECTOR);
   setTargetDAGCombine(ISD::STORE);
@@ -1047,6 +1046,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
                  MVT::v4i32}) {
+    setOperationAction(ISD::AVGFLOORS, VT, Legal);
+    setOperationAction(ISD::AVGFLOORU, VT, Legal);
+    setOperationAction(ISD::AVGCEILS, VT, Legal);
+    setOperationAction(ISD::AVGCEILU, VT, Legal);
     setOperationAction(ISD::ABDS, VT, Legal);
     setOperationAction(ISD::ABDU, VT, Legal);
   }
@@ -2096,10 +2099,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCMLTz)
     MAKE_CASE(AArch64ISD::SADDV)
     MAKE_CASE(AArch64ISD::UADDV)
-    MAKE_CASE(AArch64ISD::SRHADD)
-    MAKE_CASE(AArch64ISD::URHADD)
-    MAKE_CASE(AArch64ISD::SHADD)
-    MAKE_CASE(AArch64ISD::UHADD)
     MAKE_CASE(AArch64ISD::SDOT)
     MAKE_CASE(AArch64ISD::UDOT)
     MAKE_CASE(AArch64ISD::SMINV)
@@ -4371,9 +4370,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                           IntNo == Intrinsic::aarch64_neon_shadd);
     bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
                           IntNo == Intrinsic::aarch64_neon_urhadd);
-    unsigned Opcode =
-        IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
-                    : (IsRoundingAdd ? AArch64ISD::URHADD : AArch64ISD::UHADD);
+    unsigned Opcode = IsSignedAdd
+                          ? (IsRoundingAdd ? ISD::AVGCEILS : ISD::AVGFLOORS)
+                          : (IsRoundingAdd ? ISD::AVGCEILU : ISD::AVGFLOORU);
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
@@ -14243,89 +14242,6 @@ static SDValue performANDCombine(SDNode *N,
   return SDValue();
 }
 
-// Attempt to form urhadd(OpA, OpB) from
-// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1))
-// or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)).
-// The original form of the first expression is
-// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the
-// (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)).
-// Before this function is called the srl will have been lowered to
-// AArch64ISD::VLSHR.
-// This pass can also recognize signed variants of the patterns that use sign
-// extension instead of zero extension and form a srhadd(OpA, OpB) or a
-// shadd(OpA, OpB) from them.
-static SDValue
-performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
-                             SelectionDAG &DAG) {
-  EVT VT = N->getValueType(0);
-
-  // Since we are looking for a right shift by a constant value of 1 and we are
-  // operating on types at least 16 bits in length (sign/zero extended OpA and
-  // OpB, which are at least 8 bits), it follows that the truncate will always
-  // discard the shifted-in bit and therefore the right shift will be logical
-  // regardless of the signedness of OpA and OpB.
-  SDValue Shift = N->getOperand(0);
-  if (Shift.getOpcode() != AArch64ISD::VLSHR)
-    return SDValue();
-
-  // Is the right shift using an immediate value of 1?
-  uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
-  if (ShiftAmount != 1)
-    return SDValue();
-
-  SDValue ExtendOpA, ExtendOpB;
-  SDValue ShiftOp0 = Shift.getOperand(0);
-  unsigned ShiftOp0Opc = ShiftOp0.getOpcode();
-  if (ShiftOp0Opc == ISD::SUB) {
-
-    SDValue Xor = ShiftOp0.getOperand(1);
-    if (Xor.getOpcode() != ISD::XOR)
-      return SDValue();
-
-    // Is the XOR using a constant amount of all ones in the right hand side?
-    uint64_t C;
-    if (!isAllConstantBuildVector(Xor.getOperand(1), C))
-      return SDValue();
-
-    unsigned ElemSizeInBits = VT.getScalarSizeInBits();
-    APInt CAsAPInt(ElemSizeInBits, C);
-    if (CAsAPInt != APInt::getAllOnes(ElemSizeInBits))
-      return SDValue();
-
-    ExtendOpA = Xor.getOperand(0);
-    ExtendOpB = ShiftOp0.getOperand(0);
-  } else if (ShiftOp0Opc == ISD::ADD) {
-    ExtendOpA = ShiftOp0.getOperand(0);
-    ExtendOpB = ShiftOp0.getOperand(1);
-  } else
-    return SDValue();
-
-  unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
-  unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
-  if (!(ExtendOpAOpc == ExtendOpBOpc &&
-        (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
-    return SDValue();
-
-  // Is the result of the right shift being truncated to the same value type as
-  // the original operands, OpA and OpB?
-  SDValue OpA = ExtendOpA.getOperand(0);
-  SDValue OpB = ExtendOpB.getOperand(0);
-  EVT OpAVT = OpA.getValueType();
-  assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
-  if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
-    return SDValue();
-
-  SDLoc DL(N);
-  bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
-  bool IsRHADD = ShiftOp0Opc == ISD::SUB;
-  unsigned HADDOpc = IsSignExtend
-                         ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
-                         : (IsRHADD ? AArch64ISD::URHADD : AArch64ISD::UHADD);
-  SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB);
-
-  return ResultHADD;
-}
-
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
   case ISD::FADD:
@@ -14428,20 +14344,20 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
-  // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted
-  // subvectors from the same original vectors. Combine these into a single
-  // [us]rhadd or [us]hadd that operates on the two original vectors. Example:
-  //  (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
-  //                                        extract_subvector (v16i8 OpB,
-  //                                                           <0>))),
-  //                          (v8i8 (urhadd (extract_subvector (v16i8 OpA, <8>),
-  //                                        extract_subvector (v16i8 OpB,
-  //                                                           <8>)))))
+  // Optimise concat_vectors of two [us]avgceils or [us]avgfloors that use
+  // extracted subvectors from the same original vectors. Combine these into a
+  // single avg that operates on the two original vectors.
+  // avgceil is the target independent name for rhadd, avgfloor is a hadd.
+  // Example:
+  //  (concat_vectors (v8i8 (avgceils (extract_subvector (v16i8 OpA, <0>),
+  //                                   extract_subvector (v16i8 OpB, <0>))),
+  //                  (v8i8 (avgceils (extract_subvector (v16i8 OpA, <8>),
+  //                                   extract_subvector (v16i8 OpB, <8>)))))
   // ->
-  //  (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
+  //  (v16i8(avgceils(v16i8 OpA, v16i8 OpB)))
   if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
-      (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD ||
-       N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) {
+      (N0Opc == ISD::AVGCEILU || N0Opc == ISD::AVGCEILS ||
+       N0Opc == ISD::AVGFLOORU || N0Opc == ISD::AVGFLOORS)) {
     SDValue N00 = N0->getOperand(0);
     SDValue N01 = N0->getOperand(1);
     SDValue N10 = N1->getOperand(0);
@@ -18022,8 +17938,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performExtendCombine(N, DCI, DAG);
   case ISD::SIGN_EXTEND_INREG:
     return performSignExtendInRegCombine(N, DCI, DAG);
-  case ISD::TRUNCATE:
-    return performVectorTruncateCombine(N, DCI, DAG);
   case ISD::CONCAT_VECTORS:
     return performConcatVectorsCombine(N, DCI, DAG);
   case ISD::INSERT_SUBVECTOR:
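Since the generic combine only fires when the corresponding AVG node is legal or custom for the truncated type, another target (such as the MVE follow-up mentioned in the commit message) would opt in much like the AArch64 hunk above. A sketch with assumed vector types, not part of this patch:

  // In the target's TargetLowering constructor; the vector types listed here
  // are assumptions for illustration, not taken from this patch.
  for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32}) {
    setOperationAction(ISD::AVGFLOORS, VT, Legal);
    setOperationAction(ISD::AVGFLOORU, VT, Legal);
    setOperationAction(ISD::AVGCEILS, VT, Legal);
    setOperationAction(ISD::AVGCEILU, VT, Legal);
  }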

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 0 additions & 8 deletions
@@ -230,14 +230,6 @@ enum NodeType : unsigned {
   SADDV,
   UADDV,
 
-  // Vector halving addition
-  SHADD,
-  UHADD,
-
-  // Vector rounding halving addition
-  SRHADD,
-  URHADD,
-
   // Add Long Pairwise
   SADDLP,
   UADDLP,
