Skip to content

Commit f2412d3

Browse files
[SVE][CodeGen] Lower scalable integer vector reductions
This patch uses the existing LowerFixedLengthReductionToSVE function to also lower scalable vector reductions. A separate function has been added to lower VECREDUCE_AND & VECREDUCE_OR operations with predicate types using ptest. Lowering scalable floating-point reductions will be addressed in a follow up patch, for now these will hit the assertion added to expandVecReduce() in TargetLowering. Reviewed By: paulwalker-arm Differential Revision: https://reviews.llvm.org/D89382
1 parent f202d32 commit f2412d3

File tree

10 files changed

+1284
-28
lines changed

10 files changed

+1284
-28
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20857,7 +20857,7 @@ SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
2085720857
unsigned Opcode = N->getOpcode();
2085820858

2085920859
// VECREDUCE over 1-element vector is just an extract.
20860-
if (VT.getVectorNumElements() == 1) {
20860+
if (VT.getVectorElementCount().isScalar()) {
2086120861
SDLoc dl(N);
2086220862
SDValue Res =
2086320863
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3323,6 +3323,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
33233323
SDValue InVec = Op.getOperand(0);
33243324
SDValue EltNo = Op.getOperand(1);
33253325
EVT VecVT = InVec.getValueType();
3326+
// computeKnownBits not yet implemented for scalable vectors.
3327+
if (VecVT.isScalableVector())
3328+
break;
33263329
const unsigned EltBitWidth = VecVT.getScalarSizeInBits();
33273330
const unsigned NumSrcElts = VecVT.getVectorNumElements();
33283331

@@ -4809,6 +4812,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
48094812
case ISD::VSCALE:
48104813
assert(VT == Operand.getValueType() && "Unexpected VT!");
48114814
break;
4815+
case ISD::VECREDUCE_SMIN:
4816+
case ISD::VECREDUCE_UMAX:
4817+
if (Operand.getValueType().getScalarType() == MVT::i1)
4818+
return getNode(ISD::VECREDUCE_OR, DL, VT, Operand);
4819+
break;
4820+
case ISD::VECREDUCE_SMAX:
4821+
case ISD::VECREDUCE_UMIN:
4822+
if (Operand.getValueType().getScalarType() == MVT::i1)
4823+
return getNode(ISD::VECREDUCE_AND, DL, VT, Operand);
4824+
break;
48124825
}
48134826

48144827
SDNode *N;
@@ -5318,10 +5331,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
53185331
case ISD::MULHS:
53195332
case ISD::SDIV:
53205333
case ISD::SREM:
5321-
case ISD::SMIN:
5322-
case ISD::SMAX:
5323-
case ISD::UMIN:
5324-
case ISD::UMAX:
53255334
case ISD::SADDSAT:
53265335
case ISD::SSUBSAT:
53275336
case ISD::UADDSAT:
@@ -5330,6 +5339,22 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
53305339
assert(N1.getValueType() == N2.getValueType() &&
53315340
N1.getValueType() == VT && "Binary operator types must match!");
53325341
break;
5342+
case ISD::SMIN:
5343+
case ISD::UMAX:
5344+
assert(VT.isInteger() && "This operator does not apply to FP types!");
5345+
assert(N1.getValueType() == N2.getValueType() &&
5346+
N1.getValueType() == VT && "Binary operator types must match!");
5347+
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
5348+
return getNode(ISD::OR, DL, VT, N1, N2);
5349+
break;
5350+
case ISD::SMAX:
5351+
case ISD::UMIN:
5352+
assert(VT.isInteger() && "This operator does not apply to FP types!");
5353+
assert(N1.getValueType() == N2.getValueType() &&
5354+
N1.getValueType() == VT && "Binary operator types must match!");
5355+
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
5356+
return getNode(ISD::AND, DL, VT, N1, N2);
5357+
break;
53335358
case ISD::FADD:
53345359
case ISD::FSUB:
53355360
case ISD::FMUL:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8000,6 +8000,10 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
80008000
SDValue Op = Node->getOperand(0);
80018001
EVT VT = Op.getValueType();
80028002

8003+
if (VT.isScalableVector())
8004+
report_fatal_error(
8005+
"Expanding reductions for scalable vectors is undefined.");
8006+
80038007
// Try to use a shuffle reduction for power of two vectors.
80048008
if (VT.isPow2VectorType()) {
80058009
while (VT.getVectorNumElements() > 1) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
10131013
setOperationAction(ISD::SHL, VT, Custom);
10141014
setOperationAction(ISD::SRL, VT, Custom);
10151015
setOperationAction(ISD::SRA, VT, Custom);
1016+
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1017+
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1018+
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1019+
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1020+
setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1021+
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1022+
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1023+
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
10161024
}
10171025

10181026
// Illegal unpacked integer vector types.
@@ -1027,6 +1035,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
10271035
setOperationAction(ISD::SETCC, VT, Custom);
10281036
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
10291037
setOperationAction(ISD::TRUNCATE, VT, Custom);
1038+
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1039+
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1040+
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
10301041

10311042
// There are no legal MVT::nxv16f## based types.
10321043
if (VT != MVT::nxv16i1) {
@@ -9815,30 +9826,35 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
98159826
Op.getOpcode() == ISD::VECREDUCE_FADD ||
98169827
(Op.getOpcode() != ISD::VECREDUCE_ADD &&
98179828
SrcVT.getVectorElementType() == MVT::i64);
9818-
if (useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
9829+
if (SrcVT.isScalableVector() ||
9830+
useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
9831+
9832+
if (SrcVT.getVectorElementType() == MVT::i1)
9833+
return LowerPredReductionToSVE(Op, DAG);
9834+
98199835
switch (Op.getOpcode()) {
98209836
case ISD::VECREDUCE_ADD:
9821-
return LowerFixedLengthReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
9837+
return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
98229838
case ISD::VECREDUCE_AND:
9823-
return LowerFixedLengthReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
9839+
return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
98249840
case ISD::VECREDUCE_OR:
9825-
return LowerFixedLengthReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
9841+
return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
98269842
case ISD::VECREDUCE_SMAX:
9827-
return LowerFixedLengthReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
9843+
return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
98289844
case ISD::VECREDUCE_SMIN:
9829-
return LowerFixedLengthReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
9845+
return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
98309846
case ISD::VECREDUCE_UMAX:
9831-
return LowerFixedLengthReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
9847+
return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
98329848
case ISD::VECREDUCE_UMIN:
9833-
return LowerFixedLengthReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
9849+
return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
98349850
case ISD::VECREDUCE_XOR:
9835-
return LowerFixedLengthReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
9851+
return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
98369852
case ISD::VECREDUCE_FADD:
9837-
return LowerFixedLengthReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
9853+
return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
98389854
case ISD::VECREDUCE_FMAX:
9839-
return LowerFixedLengthReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
9855+
return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
98409856
case ISD::VECREDUCE_FMIN:
9841-
return LowerFixedLengthReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
9857+
return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
98429858
default:
98439859
llvm_unreachable("Unhandled fixed length reduction");
98449860
}
@@ -16333,20 +16349,56 @@ SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
1633316349
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
1633416350
}
1633516351

16336-
SDValue AArch64TargetLowering::LowerFixedLengthReductionToSVE(unsigned Opcode,
16337-
SDValue ScalarOp, SelectionDAG &DAG) const {
16352+
SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
16353+
SelectionDAG &DAG) const {
16354+
SDLoc DL(ReduceOp);
16355+
SDValue Op = ReduceOp.getOperand(0);
16356+
EVT OpVT = Op.getValueType();
16357+
EVT VT = ReduceOp.getValueType();
16358+
16359+
if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
16360+
return SDValue();
16361+
16362+
SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
16363+
16364+
switch (ReduceOp.getOpcode()) {
16365+
default:
16366+
return SDValue();
16367+
case ISD::VECREDUCE_OR:
16368+
return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
16369+
case ISD::VECREDUCE_AND: {
16370+
Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
16371+
return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
16372+
}
16373+
case ISD::VECREDUCE_XOR: {
16374+
SDValue ID =
16375+
DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
16376+
SDValue Cntp =
16377+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
16378+
return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
16379+
}
16380+
}
16381+
16382+
return SDValue();
16383+
}
16384+
16385+
SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
16386+
SDValue ScalarOp,
16387+
SelectionDAG &DAG) const {
1633816388
SDLoc DL(ScalarOp);
1633916389
SDValue VecOp = ScalarOp.getOperand(0);
1634016390
EVT SrcVT = VecOp.getValueType();
1634116391

16342-
SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
16343-
EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
16344-
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
16392+
if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
16393+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
16394+
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
16395+
}
1634516396

1634616397
// UADDV always returns an i64 result.
1634716398
EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
1634816399
SrcVT.getVectorElementType();
1634916400

16401+
SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
1635016402
SDValue Rdx = DAG.getNode(Opcode, DL, getPackedSVEVectorVT(ResVT), Pg, VecOp);
1635116403
SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
1635216404
Rdx, DAG.getConstant(0, DL, MVT::i64));

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,9 @@ class AArch64TargetLowering : public TargetLowering {
933933
SelectionDAG &DAG) const;
934934
SDValue LowerFixedLengthVectorLoadToSVE(SDValue Op, SelectionDAG &DAG) const;
935935
SDValue LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp, SelectionDAG &DAG) const;
936-
SDValue LowerFixedLengthReductionToSVE(unsigned Opcode, SDValue ScalarOp,
937-
SelectionDAG &DAG) const;
936+
SDValue LowerPredReductionToSVE(SDValue ScalarOp, SelectionDAG &DAG) const;
937+
SDValue LowerReductionToSVE(unsigned Opcode, SDValue ScalarOp,
938+
SelectionDAG &DAG) const;
938939
SDValue LowerFixedLengthVectorSelectToSVE(SDValue Op, SelectionDAG &DAG) const;
939940
SDValue LowerFixedLengthVectorSetccToSVE(SDValue Op, SelectionDAG &DAG) const;
940941
SDValue LowerFixedLengthVectorStoreToSVE(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)