Skip to content

Commit 12bd049

Browse files
[AArch64] Enable fixed-length vector support for partial-reductions (#142032)
This enables the use of the [us]dot, [us]addw[bt] and [us]mlal[bt] instructions in streaming mode and for wider vectors when the runtime vector length is known to be 256 bits or larger.
1 parent d16ecad commit 12bd049

File tree

2 files changed

+863
-27
lines changed

2 files changed

+863
-27
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 72 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19351935
Custom);
19361936
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19371937
Custom);
1938+
1939+
if (EnablePartialReduceNodes) {
1940+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941+
ISD::PARTIAL_REDUCE_UMLA};
1942+
// Must be lowered to SVE instructions.
1943+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948+
setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949+
}
19381950
}
19391951
}
19401952

@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22302242
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22312243
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22322244

2245+
if (EnablePartialReduceNodes) {
2246+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247+
ISD::PARTIAL_REDUCE_UMLA};
2248+
unsigned NumElts = VT.getVectorNumElements();
2249+
if (VT.getVectorElementType() == MVT::i64) {
2250+
setPartialReduceMLAAction(MLAOps, VT,
2251+
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252+
setPartialReduceMLAAction(
2253+
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254+
setPartialReduceMLAAction(
2255+
MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256+
} else if (VT.getVectorElementType() == MVT::i32) {
2257+
setPartialReduceMLAAction(MLAOps, VT,
2258+
MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259+
setPartialReduceMLAAction(
2260+
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261+
} else if (VT.getVectorElementType() == MVT::i16) {
2262+
setPartialReduceMLAAction(MLAOps, VT,
2263+
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264+
}
2265+
}
2266+
22332267
// Lower fixed length vector operations to scalable equivalents.
22342268
setOperationAction(ISD::ABDS, VT, Default);
22352269
setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2925129285
SDValue
2925229286
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2925329287
SelectionDAG &DAG) const {
29254-
bool Scalable = Op.getValueType().isScalableVector();
29255-
29256-
assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29257-
"SVE or StreamingSVE must be available when using scalable vectors.");
29258-
assert((Scalable || Subtarget->hasDotProd()) &&
29259-
"Dotprod must be available when targeting NEON dot product "
29260-
"instructions.");
29261-
2926229288
SDLoc DL(Op);
2926329289

2926429290
SDValue Acc = Op.getOperand(0);
2926529291
SDValue LHS = Op.getOperand(1);
2926629292
SDValue RHS = Op.getOperand(2);
2926729293
EVT ResultVT = Op.getValueType();
29294+
EVT OrigResultVT = ResultVT;
29295+
EVT OpVT = LHS.getValueType();
2926829296

29269-
assert((Scalable && ResultVT == MVT::nxv2i64 &&
29270-
LHS.getValueType() == MVT::nxv16i8) ||
29271-
(!Scalable && ResultVT == MVT::v2i64 &&
29272-
LHS.getValueType() == MVT::v16i8));
29297+
bool ConvertToScalable =
29298+
ResultVT.isFixedLengthVector() &&
29299+
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2927329300

29274-
EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29301+
if (ConvertToScalable) {
29302+
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29303+
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29304+
Acc = convertToScalableVector(DAG, ResultVT, Acc);
29305+
LHS = convertToScalableVector(DAG, OpVT, LHS);
29306+
RHS = convertToScalableVector(DAG, OpVT, RHS);
29307+
Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29308+
}
29309+
29310+
// Two-way and four-way partial reductions are supported by patterns.
29311+
// We only need to handle the 8-way partial reduction.
29312+
if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29313+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29314+
: Op;
29315+
29316+
EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2927529317
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2927629318
DAG.getConstant(0, DL, DotVT), LHS, RHS);
2927729319

29320+
SDValue Res;
2927829321
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29279-
if (Scalable &&
29280-
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29322+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2928129323
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2928229324
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2928329325
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29284-
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29285-
}
29286-
29287-
// Fold (nx)v4i32 into (nx)v2i64
29288-
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29289-
if (IsUnsigned) {
29290-
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29291-
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29326+
Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2929229327
} else {
29293-
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29294-
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29328+
// Fold (nx)v4i32 into (nx)v2i64
29329+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29330+
if (IsUnsigned) {
29331+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29332+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29333+
} else {
29334+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29335+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29336+
}
29337+
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29338+
Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2929529339
}
29296-
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29297-
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29340+
29341+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29342+
: Res;
2929829343
}
2929929344

2930029345
SDValue

0 commit comments

Comments
 (0)