Skip to content

Commit 88aec0f

Browse files
committed
[AArch64] Enable fixed-length vector support for partial-reductions
1 parent d8dfd42 commit 88aec0f

File tree

2 files changed

+270
-546
lines changed

2 files changed

+270
-546
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1930,6 +1930,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19301930
Custom);
19311931
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19321932
Custom);
1933+
1934+
// Must be lowered to SVE instructions.
1935+
setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
1936+
setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
1937+
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1938+
setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
1939+
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
1940+
setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
19331941
}
19341942
}
19351943

@@ -2225,6 +2233,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22252233
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22262234
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22272235

2236+
if (EnablePartialReduceNodes) {
2237+
unsigned NumElts = VT.getVectorNumElements();
2238+
if (VT.getVectorElementType() == MVT::i64) {
2239+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
2240+
Custom);
2241+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
2242+
Custom);
2243+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
2244+
Custom);
2245+
} else if (VT.getVectorElementType() == MVT::i32) {
2246+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
2247+
Custom);
2248+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
2249+
Custom);
2250+
} else if (VT.getVectorElementType() == MVT::i16) {
2251+
setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
2252+
Custom);
2253+
}
2254+
}
2255+
22282256
// Lower fixed length vector operations to scalable equivalents.
22292257
setOperationAction(ISD::ABDS, VT, Default);
22302258
setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29252,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2922429252
SDValue
2922529253
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2922629254
SelectionDAG &DAG) const {
29227-
bool Scalable = Op.getValueType().isScalableVector();
29228-
29229-
assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29230-
"SVE or StreamingSVE must be available when using scalable vectors.");
29231-
assert((Scalable || Subtarget->hasDotProd()) &&
29232-
"Dotprod must be available when targeting NEON dot product "
29233-
"instructions.");
29234-
2923529255
SDLoc DL(Op);
2923629256

2923729257
SDValue Acc = Op.getOperand(0);
2923829258
SDValue LHS = Op.getOperand(1);
2923929259
SDValue RHS = Op.getOperand(2);
2924029260
EVT ResultVT = Op.getValueType();
29261+
EVT OrigResultVT = ResultVT;
29262+
EVT OpVT = LHS.getValueType();
2924129263

29242-
assert((Scalable && ResultVT == MVT::nxv2i64 &&
29243-
LHS.getValueType() == MVT::nxv16i8) ||
29244-
(!Scalable && ResultVT == MVT::v2i64 &&
29245-
LHS.getValueType() == MVT::v16i8));
29264+
bool ConvertToScalable =
29265+
ResultVT.isFixedLengthVector() &&
29266+
useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
2924629267

29247-
EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29268+
if (ConvertToScalable) {
29269+
ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29270+
OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29271+
Acc = convertToScalableVector(DAG, ResultVT, Acc);
29272+
LHS = convertToScalableVector(DAG, OpVT, LHS);
29273+
RHS = convertToScalableVector(DAG, OpVT, RHS);
29274+
Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29275+
}
29276+
29277+
// Two-way and four-way partial reductions are supported by patterns.
29278+
// We only need to handle the 8-way partial reduction.
29279+
if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29280+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29281+
: Op;
29282+
29283+
EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2924829284
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
2924929285
DAG.getConstant(0, DL, DotVT), LHS, RHS);
2925029286

29287+
SDValue Res;
2925129288
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29252-
if (Scalable &&
29253-
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29289+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
2925429290
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2925529291
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2925629292
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29257-
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29258-
}
29259-
29260-
// Fold (nx)v4i32 into (nx)v2i64
29261-
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29262-
if (IsUnsigned) {
29263-
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29264-
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29293+
Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2926529294
} else {
29266-
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29267-
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29295+
// Fold (nx)v4i32 into (nx)v2i64
29296+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29297+
if (IsUnsigned) {
29298+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29299+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29300+
} else {
29301+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29302+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29303+
}
29304+
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29305+
Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2926829306
}
29269-
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29270-
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29307+
29308+
return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29309+
: Res;
2927129310
}
2927229311

2927329312
SDValue

0 commit comments

Comments
 (0)