@@ -1013,6 +1013,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1013
1013
setOperationAction(ISD::SHL, VT, Custom);
1014
1014
setOperationAction(ISD::SRL, VT, Custom);
1015
1015
setOperationAction(ISD::SRA, VT, Custom);
1016
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1017
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1018
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1019
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1020
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1021
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1022
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1023
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1016
1024
}
1017
1025
1018
1026
// Illegal unpacked integer vector types.
@@ -1027,6 +1035,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1027
1035
setOperationAction(ISD::SETCC, VT, Custom);
1028
1036
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1029
1037
setOperationAction(ISD::TRUNCATE, VT, Custom);
1038
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1039
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1040
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1030
1041
1031
1042
// There are no legal MVT::nxv16f## based types.
1032
1043
if (VT != MVT::nxv16i1) {
@@ -9815,30 +9826,35 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
9815
9826
Op.getOpcode() == ISD::VECREDUCE_FADD ||
9816
9827
(Op.getOpcode() != ISD::VECREDUCE_ADD &&
9817
9828
SrcVT.getVectorElementType() == MVT::i64);
9818
- if (useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
9829
+ if (SrcVT.isScalableVector() ||
9830
+ useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
9831
+
9832
+ if (SrcVT.getVectorElementType() == MVT::i1)
9833
+ return LowerPredReductionToSVE(Op, DAG);
9834
+
9819
9835
switch (Op.getOpcode()) {
9820
9836
case ISD::VECREDUCE_ADD:
9821
- return LowerFixedLengthReductionToSVE (AArch64ISD::UADDV_PRED, Op, DAG);
9837
+ return LowerReductionToSVE (AArch64ISD::UADDV_PRED, Op, DAG);
9822
9838
case ISD::VECREDUCE_AND:
9823
- return LowerFixedLengthReductionToSVE (AArch64ISD::ANDV_PRED, Op, DAG);
9839
+ return LowerReductionToSVE (AArch64ISD::ANDV_PRED, Op, DAG);
9824
9840
case ISD::VECREDUCE_OR:
9825
- return LowerFixedLengthReductionToSVE (AArch64ISD::ORV_PRED, Op, DAG);
9841
+ return LowerReductionToSVE (AArch64ISD::ORV_PRED, Op, DAG);
9826
9842
case ISD::VECREDUCE_SMAX:
9827
- return LowerFixedLengthReductionToSVE (AArch64ISD::SMAXV_PRED, Op, DAG);
9843
+ return LowerReductionToSVE (AArch64ISD::SMAXV_PRED, Op, DAG);
9828
9844
case ISD::VECREDUCE_SMIN:
9829
- return LowerFixedLengthReductionToSVE (AArch64ISD::SMINV_PRED, Op, DAG);
9845
+ return LowerReductionToSVE (AArch64ISD::SMINV_PRED, Op, DAG);
9830
9846
case ISD::VECREDUCE_UMAX:
9831
- return LowerFixedLengthReductionToSVE (AArch64ISD::UMAXV_PRED, Op, DAG);
9847
+ return LowerReductionToSVE (AArch64ISD::UMAXV_PRED, Op, DAG);
9832
9848
case ISD::VECREDUCE_UMIN:
9833
- return LowerFixedLengthReductionToSVE (AArch64ISD::UMINV_PRED, Op, DAG);
9849
+ return LowerReductionToSVE (AArch64ISD::UMINV_PRED, Op, DAG);
9834
9850
case ISD::VECREDUCE_XOR:
9835
- return LowerFixedLengthReductionToSVE (AArch64ISD::EORV_PRED, Op, DAG);
9851
+ return LowerReductionToSVE (AArch64ISD::EORV_PRED, Op, DAG);
9836
9852
case ISD::VECREDUCE_FADD:
9837
- return LowerFixedLengthReductionToSVE (AArch64ISD::FADDV_PRED, Op, DAG);
9853
+ return LowerReductionToSVE (AArch64ISD::FADDV_PRED, Op, DAG);
9838
9854
case ISD::VECREDUCE_FMAX:
9839
- return LowerFixedLengthReductionToSVE (AArch64ISD::FMAXNMV_PRED, Op, DAG);
9855
+ return LowerReductionToSVE (AArch64ISD::FMAXNMV_PRED, Op, DAG);
9840
9856
case ISD::VECREDUCE_FMIN:
9841
- return LowerFixedLengthReductionToSVE (AArch64ISD::FMINNMV_PRED, Op, DAG);
9857
+ return LowerReductionToSVE (AArch64ISD::FMINNMV_PRED, Op, DAG);
9842
9858
default:
9843
9859
llvm_unreachable("Unhandled fixed length reduction");
9844
9860
}
@@ -16333,20 +16349,56 @@ SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
16333
16349
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
16334
16350
}
16335
16351
16336
- SDValue AArch64TargetLowering::LowerFixedLengthReductionToSVE(unsigned Opcode,
16337
- SDValue ScalarOp, SelectionDAG &DAG) const {
16352
+ SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
16353
+ SelectionDAG &DAG) const {
16354
+ SDLoc DL(ReduceOp);
16355
+ SDValue Op = ReduceOp.getOperand(0);
16356
+ EVT OpVT = Op.getValueType();
16357
+ EVT VT = ReduceOp.getValueType();
16358
+
16359
+ if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
16360
+ return SDValue();
16361
+
16362
+ SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
16363
+
16364
+ switch (ReduceOp.getOpcode()) {
16365
+ default:
16366
+ return SDValue();
16367
+ case ISD::VECREDUCE_OR:
16368
+ return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
16369
+ case ISD::VECREDUCE_AND: {
16370
+ Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
16371
+ return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
16372
+ }
16373
+ case ISD::VECREDUCE_XOR: {
16374
+ SDValue ID =
16375
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
16376
+ SDValue Cntp =
16377
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
16378
+ return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
16379
+ }
16380
+ }
16381
+
16382
+ return SDValue();
16383
+ }
16384
+
16385
+ SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
16386
+ SDValue ScalarOp,
16387
+ SelectionDAG &DAG) const {
16338
16388
SDLoc DL(ScalarOp);
16339
16389
SDValue VecOp = ScalarOp.getOperand(0);
16340
16390
EVT SrcVT = VecOp.getValueType();
16341
16391
16342
- SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
16343
- EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
16344
- VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
16392
+ if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
16393
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
16394
+ VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
16395
+ }
16345
16396
16346
16397
// UADDV always returns an i64 result.
16347
16398
EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
16348
16399
SrcVT.getVectorElementType();
16349
16400
16401
+ SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
16350
16402
SDValue Rdx = DAG.getNode(Opcode, DL, getPackedSVEVectorVT(ResVT), Pg, VecOp);
16351
16403
SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
16352
16404
Rdx, DAG.getConstant(0, DL, MVT::i64));
0 commit comments