Skip to content

Commit fd66c5f

Browse files
committed
[AArch64] Enable fixed-length vector support for partial-reductions
1 parent d8dfd42 commit fd66c5f

File tree

2 files changed

+272
-546
lines changed

2 files changed

+272
-546
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19301930
Custom);
19311931
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19321932
Custom);
1933+
1934+
if (EnablePartialReduceNodes) {
1935+
// Must be lowered to SVE instructions.
1936+
setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
1937+
setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
1938+
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1939+
setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
1940+
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
1941+
setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
1942+
}
19331943
}
19341944
}
19351945

@@ -2225,6 +2235,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22252235
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22262236
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22272237

2238+
if (EnablePartialReduceNodes) {
2239+
unsigned NumElts = VT.getVectorNumElements();
2240+
if (VT.getVectorElementType() == MVT::i64) {
2241+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
2242+
Custom);
2243+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
2244+
Custom);
2245+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
2246+
Custom);
2247+
} else if (VT.getVectorElementType() == MVT::i32) {
2248+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
2249+
Custom);
2250+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
2251+
Custom);
2252+
} else if (VT.getVectorElementType() == MVT::i16) {
2253+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
2254+
Custom);
2255+
}
2256+
}
2257+
22282258
// Lower fixed length vector operations to scalable equivalents.
22292259
setOperationAction(ISD::ABDS, VT, Default);
22302260
setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29254,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2922429254
SDValue
2922529255
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2922629256
SelectionDAG &DAG) const {
29227-
bool Scalable = Op.getValueType().isScalableVector();
29228-
29229-
assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29230-
"SVE or StreamingSVE must be available when using scalable vectors.");
29231-
assert((Scalable || Subtarget->hasDotProd()) &&
29232-
"Dotprod must be available when targeting NEON dot product "
29233-
"instructions.");
29234-
2923529257
SDLoc DL(Op);
2923629258

2923729259
SDValue Acc = Op.getOperand(0);
2923829260
SDValue LHS = Op.getOperand(1);
2923929261
SDValue RHS = Op.getOperand(2);
2924029262
EVT ResultVT = Op.getValueType();
29263+
EVT OrigResultVT = ResultVT;
29264+
EVT OpVT = LHS.getValueType();
2924129265

29242-
assert((Scalable && ResultVT == MVT::nxv2i64 &&
29243-
LHS.getValueType() == MVT::nxv16i8) ||
29244-
(!Scalable && ResultVT == MVT::v2i64 &&
29245-
LHS.getValueType() == MVT::v16i8));
29266+
bool ConvertToScalable =
29267+
ResultVT.isFixedLengthVector() &&
29268+
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2924629269

29247-
EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29270+
if (ConvertToScalable) {
29271+
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29272+
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29273+
Acc = convertToScalableVector(DAG, ResultVT, Acc);
29274+
LHS = convertToScalableVector(DAG, OpVT, LHS);
29275+
RHS = convertToScalableVector(DAG, OpVT, RHS);
29276+
Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29277+
}
29278+
29279+
// Two-way and four-way partial reductions are supported by patterns.
29280+
// We only need to handle the 8-way partial reduction.
29281+
if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29282+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29283+
: Op;
29284+
29285+
EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2924829286
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2924929287
DAG.getConstant(0, DL, DotVT), LHS, RHS);
2925029288

29289+
SDValue Res;
2925129290
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29252-
if (Scalable &&
29253-
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29291+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2925429292
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2925529293
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2925629294
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29257-
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29258-
}
29259-
29260-
// Fold (nx)v4i32 into (nx)v2i64
29261-
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29262-
if (IsUnsigned) {
29263-
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29264-
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29295+
Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2926529296
} else {
29266-
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29267-
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29297+
// Fold (nx)v4i32 into (nx)v2i64
29298+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29299+
if (IsUnsigned) {
29300+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29301+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29302+
} else {
29303+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29304+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29305+
}
29306+
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29307+
Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2926829308
}
29269-
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29270-
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29309+
29310+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29311+
: Res;
2927129312
}
2927229313

2927329314
SDValue

0 commit comments

Comments
 (0)