Skip to content

Commit af80a43

Browse files
[ARM] Generate [SU]RHADD from (b - (~a)) >> 1
Summary: Teach LLVM to recognize the above pattern, which is usually a transformation of (a + b + 1) >> 1, where the operands are either signed or unsigned types. Subscribers: kristof.beyls, hiraditya, danielkiss, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D82669
1 parent b3b9528 commit af80a43

File tree

4 files changed

+483
-54
lines changed

4 files changed

+483
-54
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
838838
setOperationAction(ISD::UADDSAT, VT, Legal);
839839
setOperationAction(ISD::SSUBSAT, VT, Legal);
840840
setOperationAction(ISD::USUBSAT, VT, Legal);
841+
842+
setOperationAction(ISD::TRUNCATE, VT, Custom);
841843
}
842844
for (MVT VT : { MVT::v4f16, MVT::v2f32,
843845
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -1432,6 +1434,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
14321434
MAKE_CASE(AArch64ISD::FCMLTz)
14331435
MAKE_CASE(AArch64ISD::SADDV)
14341436
MAKE_CASE(AArch64ISD::UADDV)
1437+
MAKE_CASE(AArch64ISD::SRHADD)
1438+
MAKE_CASE(AArch64ISD::URHADD)
14351439
MAKE_CASE(AArch64ISD::SMINV)
14361440
MAKE_CASE(AArch64ISD::UMINV)
14371441
MAKE_CASE(AArch64ISD::SMAXV)
@@ -3260,6 +3264,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
32603264
return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
32613265
Op.getOperand(3));
32623266
}
3267+
3268+
case Intrinsic::aarch64_neon_srhadd:
3269+
case Intrinsic::aarch64_neon_urhadd: {
3270+
bool IsSignedAdd = IntNo == Intrinsic::aarch64_neon_srhadd;
3271+
unsigned Opcode = IsSignedAdd ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
3272+
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
3273+
Op.getOperand(2));
3274+
}
32633275
}
32643276
}
32653277

@@ -3524,6 +3536,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
35243536
return LowerDYNAMIC_STACKALLOC(Op, DAG);
35253537
case ISD::VSCALE:
35263538
return LowerVSCALE(Op, DAG);
3539+
case ISD::TRUNCATE:
3540+
return LowerTRUNCATE(Op, DAG);
35273541
case ISD::LOAD:
35283542
if (useSVEForFixedLengthVectorVT(Op.getValueType()))
35293543
return LowerFixedLengthVectorLoadToSVE(Op, DAG);
@@ -8773,6 +8787,78 @@ static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
87738787
return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
87748788
}
87758789

8790+
// Attempt to form urhadd(OpA, OpB) from
8791+
// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)).
8792+
// The original form of this expression is
8793+
// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and before this function
8794+
// is called the srl will have been lowered to AArch64ISD::VLSHR and the
8795+
// ((OpA + OpB + 1) >> 1) expression will have been changed to (OpB - (~OpA)).
8796+
// This pass can also recognize a variant of this pattern that uses sign
8797+
// extension instead of zero extension and form a srhadd(OpA, OpB) from it.
8798+
SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
8799+
SelectionDAG &DAG) const {
8800+
EVT VT = Op.getValueType();
8801+
8802+
if (!VT.isVector() || VT.isScalableVector())
8803+
return Op;
8804+
8805+
// Since we are looking for a right shift by a constant value of 1 and we are
8806+
// operating on types at least 16 bits in length (sign/zero extended OpA and
8807+
// OpB, which are at least 8 bits), it follows that the truncate will always
8808+
// discard the shifted-in bit and therefore the right shift will be logical
8809+
// regardless of the signedness of OpA and OpB.
8810+
SDValue Shift = Op.getOperand(0);
8811+
if (Shift.getOpcode() != AArch64ISD::VLSHR)
8812+
return Op;
8813+
8814+
// Is the right shift using an immediate value of 1?
8815+
uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
8816+
if (ShiftAmount != 1)
8817+
return Op;
8818+
8819+
SDValue Sub = Shift->getOperand(0);
8820+
if (Sub.getOpcode() != ISD::SUB)
8821+
return Op;
8822+
8823+
SDValue Xor = Sub.getOperand(1);
8824+
if (Xor.getOpcode() != ISD::XOR)
8825+
return Op;
8826+
8827+
SDValue ExtendOpA = Xor.getOperand(0);
8828+
SDValue ExtendOpB = Sub.getOperand(0);
8829+
unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
8830+
unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
8831+
if (!(ExtendOpAOpc == ExtendOpBOpc &&
8832+
(ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
8833+
return Op;
8834+
8835+
// Is the result of the right shift being truncated to the same value type as
8836+
// the original operands, OpA and OpB?
8837+
SDValue OpA = ExtendOpA.getOperand(0);
8838+
SDValue OpB = ExtendOpB.getOperand(0);
8839+
EVT OpAVT = OpA.getValueType();
8840+
assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
8841+
if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
8842+
return Op;
8843+
8844+
// Is the XOR using a constant amount of all ones in the right hand side?
8845+
uint64_t C;
8846+
if (!isAllConstantBuildVector(Xor.getOperand(1), C))
8847+
return Op;
8848+
8849+
unsigned ElemSizeInBits = VT.getScalarSizeInBits();
8850+
APInt CAsAPInt(ElemSizeInBits, C);
8851+
if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
8852+
return Op;
8853+
8854+
SDLoc DL(Op);
8855+
bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
8856+
unsigned RHADDOpc = IsSignExtend ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
8857+
SDValue ResultURHADD = DAG.getNode(RHADDOpc, DL, VT, OpA, OpB);
8858+
8859+
return ResultURHADD;
8860+
}
8861+
87768862
SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
87778863
SelectionDAG &DAG) const {
87788864
EVT VT = Op.getValueType();
@@ -10982,6 +11068,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1098211068
SDLoc dl(N);
1098311069
EVT VT = N->getValueType(0);
1098411070
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
11071+
unsigned N0Opc = N0->getOpcode(), N1Opc = N1->getOpcode();
1098511072

1098611073
// Optimize concat_vectors of truncated vectors, where the intermediate
1098711074
// type is illegal, to avoid said illegality, e.g.,
@@ -10994,9 +11081,8 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1099411081
// This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
1099511082
// on both input and result type, so we might generate worse code.
1099611083
// On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
10997-
if (N->getNumOperands() == 2 &&
10998-
N0->getOpcode() == ISD::TRUNCATE &&
10999-
N1->getOpcode() == ISD::TRUNCATE) {
11084+
if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
11085+
N1Opc == ISD::TRUNCATE) {
1100011086
SDValue N00 = N0->getOperand(0);
1100111087
SDValue N10 = N1->getOperand(0);
1100211088
EVT N00VT = N00.getValueType();
@@ -11021,6 +11107,52 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1102111107
if (DCI.isBeforeLegalizeOps())
1102211108
return SDValue();
1102311109

11110+
// Optimise concat_vectors of two [us]rhadds that use extracted subvectors
11111+
// from the same original vectors. Combine these into a single [us]rhadd that
11112+
// operates on the two original vectors. Example:
11113+
// (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
11114+
// extract_subvector (v16i8 OpB,
11115+
// <0>))),
11116+
// (v8i8 (urhadd (extract_subvector (v16i8 OpA, <8>),
11117+
// extract_subvector (v16i8 OpB,
11118+
// <8>)))))
11119+
// ->
11120+
// (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
11121+
if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
11122+
(N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD)) {
11123+
SDValue N00 = N0->getOperand(0);
11124+
SDValue N01 = N0->getOperand(1);
11125+
SDValue N10 = N1->getOperand(0);
11126+
SDValue N11 = N1->getOperand(1);
11127+
11128+
EVT N00VT = N00.getValueType();
11129+
EVT N10VT = N10.getValueType();
11130+
11131+
if (N00->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
11132+
N01->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
11133+
N10->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
11134+
N11->getOpcode() == ISD::EXTRACT_SUBVECTOR && N00VT == N10VT) {
11135+
SDValue N00Source = N00->getOperand(0);
11136+
SDValue N01Source = N01->getOperand(0);
11137+
SDValue N10Source = N10->getOperand(0);
11138+
SDValue N11Source = N11->getOperand(0);
11139+
11140+
if (N00Source == N10Source && N01Source == N11Source &&
11141+
N00Source.getValueType() == VT && N01Source.getValueType() == VT) {
11142+
assert(N0.getValueType() == N1.getValueType());
11143+
11144+
uint64_t N00Index = N00.getConstantOperandVal(1);
11145+
uint64_t N01Index = N01.getConstantOperandVal(1);
11146+
uint64_t N10Index = N10.getConstantOperandVal(1);
11147+
uint64_t N11Index = N11.getConstantOperandVal(1);
11148+
11149+
if (N00Index == N01Index && N10Index == N11Index && N00Index == 0 &&
11150+
N10Index == N00VT.getVectorNumElements())
11151+
return DAG.getNode(N0Opc, dl, VT, N00Source, N01Source);
11152+
}
11153+
}
11154+
}
11155+
1102411156
// If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
1102511157
// splat. The indexed instructions are going to be expecting a DUPLANE64, so
1102611158
// canonicalise to that.
@@ -11039,7 +11171,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
1103911171
// becomes
1104011172
// (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
1104111173

11042-
if (N1->getOpcode() != ISD::BITCAST)
11174+
if (N1Opc != ISD::BITCAST)
1104311175
return SDValue();
1104411176
SDValue RHS = N1->getOperand(0);
1104511177
MVT RHSTy = RHS.getValueType().getSimpleVT();

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ enum NodeType : unsigned {
187187
SADDV,
188188
UADDV,
189189

190+
// Vector rounding halving addition
191+
SRHADD,
192+
URHADD,
193+
190194
// Vector across-lanes min/max
191195
// Only the lower result lane is defined.
192196
SMINV,
@@ -863,6 +867,7 @@ class AArch64TargetLowering : public TargetLowering {
863867
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
864868
SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const;
865869
SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const;
870+
SDValue LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const;
866871
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
867872
SDValue LowerATOMIC_LOAD_SUB(SDValue Op, SelectionDAG &DAG) const;
868873
SDValue LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,9 @@ def AArch64uminv : SDNode<"AArch64ISD::UMINV", SDT_AArch64UnaryVec>;
554554
def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
555555
def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
556556

557+
def AArch64srhadd : SDNode<"AArch64ISD::SRHADD", SDT_AArch64binvec>;
558+
def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
559+
557560
def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
558561
def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
559562
def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
@@ -4073,7 +4076,7 @@ defm SQRDMULH : SIMDThreeSameVectorHS<1,0b10110,"sqrdmulh",int_aarch64_neon_sqrd
40734076
defm SQRSHL : SIMDThreeSameVector<0,0b01011,"sqrshl", int_aarch64_neon_sqrshl>;
40744077
defm SQSHL : SIMDThreeSameVector<0,0b01001,"sqshl", int_aarch64_neon_sqshl>;
40754078
defm SQSUB : SIMDThreeSameVector<0,0b00101,"sqsub", int_aarch64_neon_sqsub>;
4076-
defm SRHADD : SIMDThreeSameVectorBHS<0,0b00010,"srhadd",int_aarch64_neon_srhadd>;
4079+
defm SRHADD : SIMDThreeSameVectorBHS<0,0b00010,"srhadd", AArch64srhadd>;
40774080
defm SRSHL : SIMDThreeSameVector<0,0b01010,"srshl", int_aarch64_neon_srshl>;
40784081
defm SSHL : SIMDThreeSameVector<0,0b01000,"sshl", int_aarch64_neon_sshl>;
40794082
defm SUB : SIMDThreeSameVector<1,0b10000,"sub", sub>;
@@ -4090,7 +4093,7 @@ defm UQADD : SIMDThreeSameVector<1,0b00001,"uqadd", int_aarch64_neon_uqadd>;
40904093
defm UQRSHL : SIMDThreeSameVector<1,0b01011,"uqrshl", int_aarch64_neon_uqrshl>;
40914094
defm UQSHL : SIMDThreeSameVector<1,0b01001,"uqshl", int_aarch64_neon_uqshl>;
40924095
defm UQSUB : SIMDThreeSameVector<1,0b00101,"uqsub", int_aarch64_neon_uqsub>;
4093-
defm URHADD : SIMDThreeSameVectorBHS<1,0b00010,"urhadd", int_aarch64_neon_urhadd>;
4096+
defm URHADD : SIMDThreeSameVectorBHS<1,0b00010,"urhadd", AArch64urhadd>;
40944097
defm URSHL : SIMDThreeSameVector<1,0b01010,"urshl", int_aarch64_neon_urshl>;
40954098
defm USHL : SIMDThreeSameVector<1,0b01000,"ushl", int_aarch64_neon_ushl>;
40964099
defm SQRDMLAH : SIMDThreeSameVectorSQRDMLxHTiedHS<1,0b10000,"sqrdmlah",

0 commit comments

Comments
 (0)