Skip to content

Commit 0fa3bd5

Browse files
committed
[AArch64] Enable fixed-length vector support for partial-reductions
1 parent 0c36582 commit 0fa3bd5

File tree

2 files changed

+276
-546
lines changed

2 files changed

+276
-546
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 72 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19351935
Custom);
19361936
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19371937
Custom);
1938+
1939+
if (EnablePartialReduceNodes) {
1940+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941+
ISD::PARTIAL_REDUCE_UMLA};
1942+
// Must be lowered to SVE instructions.
1943+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948+
setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949+
}
19381950
}
19391951
}
19401952

@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22302242
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22312243
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22322244

2245+
if (EnablePartialReduceNodes) {
2246+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247+
ISD::PARTIAL_REDUCE_UMLA};
2248+
unsigned NumElts = VT.getVectorNumElements();
2249+
if (VT.getVectorElementType() == MVT::i64) {
2250+
setPartialReduceMLAAction(MLAOps, VT,
2251+
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252+
setPartialReduceMLAAction(
2253+
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254+
setPartialReduceMLAAction(
2255+
MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256+
} else if (VT.getVectorElementType() == MVT::i32) {
2257+
setPartialReduceMLAAction(MLAOps, VT,
2258+
MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259+
setPartialReduceMLAAction(
2260+
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261+
} else if (VT.getVectorElementType() == MVT::i16) {
2262+
setPartialReduceMLAAction(MLAOps, VT,
2263+
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264+
}
2265+
}
2266+
22332267
// Lower fixed length vector operations to scalable equivalents.
22342268
setOperationAction(ISD::ABDS, VT, Default);
22352269
setOperationAction(ISD::ABDU, VT, Default);
@@ -29229,50 +29263,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2922929263
SDValue
2923029264
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2923129265
SelectionDAG &DAG) const {
29232-
bool Scalable = Op.getValueType().isScalableVector();
29233-
29234-
assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29235-
"SVE or StreamingSVE must be available when using scalable vectors.");
29236-
assert((Scalable || Subtarget->hasDotProd()) &&
29237-
"Dotprod must be available when targeting NEON dot product "
29238-
"instructions.");
29239-
2924029266
SDLoc DL(Op);
2924129267

2924229268
SDValue Acc = Op.getOperand(0);
2924329269
SDValue LHS = Op.getOperand(1);
2924429270
SDValue RHS = Op.getOperand(2);
2924529271
EVT ResultVT = Op.getValueType();
29272+
EVT OrigResultVT = ResultVT;
29273+
EVT OpVT = LHS.getValueType();
2924629274

29247-
assert((Scalable && ResultVT == MVT::nxv2i64 &&
29248-
LHS.getValueType() == MVT::nxv16i8) ||
29249-
(!Scalable && ResultVT == MVT::v2i64 &&
29250-
LHS.getValueType() == MVT::v16i8));
29275+
bool ConvertToScalable =
29276+
ResultVT.isFixedLengthVector() &&
29277+
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2925129278

29252-
EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29279+
if (ConvertToScalable) {
29280+
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29281+
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29282+
Acc = convertToScalableVector(DAG, ResultVT, Acc);
29283+
LHS = convertToScalableVector(DAG, OpVT, LHS);
29284+
RHS = convertToScalableVector(DAG, OpVT, RHS);
29285+
Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29286+
}
29287+
29288+
// Two-way and four-way partial reductions are supported by patterns.
29289+
// We only need to handle the 8-way partial reduction.
29290+
if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29291+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29292+
: Op;
29293+
29294+
EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2925329295
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2925429296
DAG.getConstant(0, DL, DotVT), LHS, RHS);
2925529297

29298+
SDValue Res;
2925629299
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29257-
if (Scalable &&
29258-
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29300+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2925929301
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2926029302
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2926129303
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29262-
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29263-
}
29264-
29265-
// Fold (nx)v4i32 into (nx)v2i64
29266-
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29267-
if (IsUnsigned) {
29268-
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29269-
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29304+
Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2927029305
} else {
29271-
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29272-
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29306+
// Fold (nx)v4i32 into (nx)v2i64
29307+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29308+
if (IsUnsigned) {
29309+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29310+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29311+
} else {
29312+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29313+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29314+
}
29315+
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29316+
Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2927329317
}
29274-
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29275-
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29318+
29319+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29320+
: Res;
2927629321
}
2927729322

2927829323
SDValue

0 commit comments

Comments
 (0)