@@ -838,6 +838,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::UADDSAT, VT, Legal);
       setOperationAction(ISD::SSUBSAT, VT, Legal);
       setOperationAction(ISD::USUBSAT, VT, Legal);
+
+      setOperationAction(ISD::TRUNCATE, VT, Custom);
     }
     for (MVT VT : { MVT::v4f16, MVT::v2f32,
                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -1432,6 +1434,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCMLTz)
     MAKE_CASE(AArch64ISD::SADDV)
     MAKE_CASE(AArch64ISD::UADDV)
+    MAKE_CASE(AArch64ISD::SRHADD)
+    MAKE_CASE(AArch64ISD::URHADD)
     MAKE_CASE(AArch64ISD::SMINV)
     MAKE_CASE(AArch64ISD::UMINV)
     MAKE_CASE(AArch64ISD::SMAXV)
@@ -3260,6 +3264,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
                        Op.getOperand(3));
   }
+
+  case Intrinsic::aarch64_neon_srhadd:
+  case Intrinsic::aarch64_neon_urhadd: {
+    bool IsSignedAdd = IntNo == Intrinsic::aarch64_neon_srhadd;
+    unsigned Opcode = IsSignedAdd ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
+    return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
+                       Op.getOperand(2));
+  }
   }
 }
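
For reference, these intrinsics reach the backend through the ACLE rounding-halving-add builtins in arm_neon.h. A minimal sketch of source that exercises the new case (illustrative only; the instruction noted in the comments is the expected selection outcome of this patch as a whole, not something this hunk alone guarantees):

#include <arm_neon.h>

// vrhaddq_u8 computes the per-lane rounding halving add (a + b + 1) >> 1.
// Clang emits @llvm.aarch64.neon.urhadd for it, which the case above now
// lowers to the AArch64ISD::URHADD node.
uint8x16_t rounding_avg_u8(uint8x16_t A, uint8x16_t B) {
  return vrhaddq_u8(A, B); // expected: urhadd v0.16b, v0.16b, v1.16b
}

// Signed variant, mapped to AArch64ISD::SRHADD via @llvm.aarch64.neon.srhadd.
int8x16_t rounding_avg_s8(int8x16_t A, int8x16_t B) {
  return vrhaddq_s8(A, B); // expected: srhadd v0.16b, v0.16b, v1.16b
}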
@@ -3524,6 +3536,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
   case ISD::VSCALE:
     return LowerVSCALE(Op, DAG);
+  case ISD::TRUNCATE:
+    return LowerTRUNCATE(Op, DAG);
   case ISD::LOAD:
     if (useSVEForFixedLengthVectorVT(Op.getValueType()))
       return LowerFixedLengthVectorLoadToSVE(Op, DAG);
@@ -8773,6 +8787,78 @@ static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
   return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
 }
 
+// Attempt to form urhadd(OpA, OpB) from
+// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)).
+// The original form of this expression is
+// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and before this function
+// is called the srl will have been lowered to AArch64ISD::VLSHR and the
+// ((OpA + OpB + 1) >> 1) expression will have been changed to (OpB - (~OpA)).
+// This function can also recognize a variant of this pattern that uses sign
+// extension instead of zero extension and form a srhadd(OpA, OpB) from it.
+SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+
+  if (!VT.isVector() || VT.isScalableVector())
+    return Op;
+
+  // Since we are looking for a right shift by a constant value of 1 and we are
+  // operating on types at least 16 bits in length (sign/zero extended OpA and
+  // OpB, which are at least 8 bits), it follows that the truncate will always
+  // discard the shifted-in bit and therefore the right shift will be logical
+  // regardless of the signedness of OpA and OpB.
+  SDValue Shift = Op.getOperand(0);
+  if (Shift.getOpcode() != AArch64ISD::VLSHR)
+    return Op;
+
+  // Is the right shift using an immediate value of 1?
+  uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
+  if (ShiftAmount != 1)
+    return Op;
+
+  SDValue Sub = Shift->getOperand(0);
+  if (Sub.getOpcode() != ISD::SUB)
+    return Op;
+
+  SDValue Xor = Sub.getOperand(1);
+  if (Xor.getOpcode() != ISD::XOR)
+    return Op;
+
+  SDValue ExtendOpA = Xor.getOperand(0);
+  SDValue ExtendOpB = Sub.getOperand(0);
+  unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
+  unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
+  if (!(ExtendOpAOpc == ExtendOpBOpc &&
+        (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
+    return Op;
+
+  // Is the result of the right shift being truncated to the same value type as
+  // the original operands, OpA and OpB?
+  SDValue OpA = ExtendOpA.getOperand(0);
+  SDValue OpB = ExtendOpB.getOperand(0);
+  EVT OpAVT = OpA.getValueType();
+  assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
+  if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
+    return Op;
+
+  // Is the XOR using a constant amount of all ones in the right hand side?
+  uint64_t C;
+  if (!isAllConstantBuildVector(Xor.getOperand(1), C))
+    return Op;
+
+  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
+  APInt CAsAPInt(ElemSizeInBits, C);
+  if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
+    return Op;
+
+  SDLoc DL(Op);
+  bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
+  unsigned RHADDOpc = IsSignExtend ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
+  SDValue ResultURHADD = DAG.getNode(RHADDOpc, DL, VT, OpA, OpB);
+
+  return ResultURHADD;
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
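
The rewrite LowerTRUNCATE undoes relies on the two's-complement identity ~x == -x - 1, so zext(OpB) - ~zext(OpA) == zext(OpA) + zext(OpB) + 1 modulo the width of the extended type. A standalone sketch, not part of the patch, that exhaustively verifies the 8-bit unsigned case the pattern matcher targets:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A) {
    for (unsigned B = 0; B < 256; ++B) {
      uint16_t ExtA = A, ExtB = B; // zext to the wider type
      uint16_t NotA = ~ExtA;       // xor with all-ones, as matched above
      // (OpB - ~OpA) >> 1, truncated back to 8 bits: the DAG shape matched.
      uint8_t Rewritten = (uint16_t)(ExtB - NotA) >> 1;
      // The original rounding halving add.
      uint8_t Expected = (A + B + 1) >> 1;
      assert(Rewritten == Expected);
    }
  }
  return 0;
}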
@@ -10982,6 +11068,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+  unsigned N0Opc = N0->getOpcode(), N1Opc = N1->getOpcode();
 
   // Optimize concat_vectors of truncated vectors, where the intermediate
   // type is illegal, to avoid said illegality, e.g.,
@@ -10994,9 +11081,8 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   // This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
   // on both input and result type, so we might generate worse code.
   // On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
-  if (N->getNumOperands() == 2 &&
-      N0->getOpcode() == ISD::TRUNCATE &&
-      N1->getOpcode() == ISD::TRUNCATE) {
+  if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
+      N1Opc == ISD::TRUNCATE) {
     SDValue N00 = N0->getOperand(0);
     SDValue N10 = N1->getOperand(0);
     EVT N00VT = N00.getValueType();
@@ -11021,6 +11107,52 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
+  // Optimise concat_vectors of two [us]rhadds that use extracted subvectors
+  // from the same original vectors. Combine these into a single [us]rhadd that
+  // operates on the two original vectors. Example:
+  //  (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
+  //                                        extract_subvector (v16i8 OpB,
+  //                                        <0>))),
+  //                         (v8i8 (urhadd (extract_subvector (v16i8 OpA, <8>),
+  //                                        extract_subvector (v16i8 OpB,
+  //                                        <8>)))))
+  // ->
+  //  (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
+  if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
+      (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD)) {
+    SDValue N00 = N0->getOperand(0);
+    SDValue N01 = N0->getOperand(1);
+    SDValue N10 = N1->getOperand(0);
+    SDValue N11 = N1->getOperand(1);
+
+    EVT N00VT = N00.getValueType();
+    EVT N10VT = N10.getValueType();
+
+    if (N00->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        N01->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        N10->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        N11->getOpcode() == ISD::EXTRACT_SUBVECTOR && N00VT == N10VT) {
+      SDValue N00Source = N00->getOperand(0);
+      SDValue N01Source = N01->getOperand(0);
+      SDValue N10Source = N10->getOperand(0);
+      SDValue N11Source = N11->getOperand(0);
+
+      if (N00Source == N10Source && N01Source == N11Source &&
+          N00Source.getValueType() == VT && N01Source.getValueType() == VT) {
+        assert(N0.getValueType() == N1.getValueType());
+
+        uint64_t N00Index = N00.getConstantOperandVal(1);
+        uint64_t N01Index = N01.getConstantOperandVal(1);
+        uint64_t N10Index = N10.getConstantOperandVal(1);
+        uint64_t N11Index = N11.getConstantOperandVal(1);
+
+        if (N00Index == N01Index && N10Index == N11Index && N00Index == 0 &&
+            N10Index == N00VT.getVectorNumElements())
+          return DAG.getNode(N0Opc, dl, VT, N00Source, N01Source);
+      }
+    }
+  }
+
   // If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
   // splat. The indexed instructions are going to be expecting a DUPLANE64, so
   // canonicalise to that.
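
The combine above is aimed at code where a rounding average over a 128-bit vector has been split into two 64-bit halves. A hedged C++ sketch of source that can produce this shape (illustrative only; whether the two half-width urhadds actually appear depends on how the vectoriser and type legaliser split the widened arithmetic):

#include <cstdint>

// Rounding average through a 16-bit intermediate. Vectorised for v16i8, the
// zext/add/shift/trunc chain can become per-half urhadd nodes whose results
// are concatenated; the combine then folds them into a single urhadd on the
// full v16i8 operands.
void rounding_avg(const uint8_t *A, const uint8_t *B, uint8_t *Out) {
  for (int I = 0; I < 16; ++I)
    Out[I] = (uint8_t)(((uint16_t)A[I] + (uint16_t)B[I] + 1) >> 1);
}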
@@ -11039,7 +11171,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   // becomes
   //    (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
 
-  if (N1->getOpcode() != ISD::BITCAST)
+  if (N1Opc != ISD::BITCAST)
     return SDValue();
   SDValue RHS = N1->getOperand(0);
   MVT RHSTy = RHS.getValueType().getSimpleVT();
0 commit comments