@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1935
1935
Custom);
1936
1936
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
1937
1937
Custom);
1938
+
1939
+ if (EnablePartialReduceNodes) {
1940
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941
+ ISD::PARTIAL_REDUCE_UMLA};
1942
+ // Must be lowered to SVE instructions.
1943
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948
+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949
+ }
1938
1950
}
1939
1951
}
1940
1952
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
2230
2242
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2231
2243
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2232
2244
2245
+ if (EnablePartialReduceNodes) {
2246
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247
+ ISD::PARTIAL_REDUCE_UMLA};
2248
+ unsigned NumElts = VT.getVectorNumElements();
2249
+ if (VT.getVectorElementType() == MVT::i64) {
2250
+ setPartialReduceMLAAction(MLAOps, VT,
2251
+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252
+ setPartialReduceMLAAction(
2253
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254
+ setPartialReduceMLAAction(
2255
+ MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256
+ } else if (VT.getVectorElementType() == MVT::i32) {
2257
+ setPartialReduceMLAAction(MLAOps, VT,
2258
+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259
+ setPartialReduceMLAAction(
2260
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261
+ } else if (VT.getVectorElementType() == MVT::i16) {
2262
+ setPartialReduceMLAAction(MLAOps, VT,
2263
+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264
+ }
2265
+ }
2266
+
2233
2267
// Lower fixed length vector operations to scalable equivalents.
2234
2268
setOperationAction(ISD::ABDS, VT, Default);
2235
2269
setOperationAction(ISD::ABDU, VT, Default);
@@ -29251,50 +29285,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29251
29285
// Custom lowering for partial-reduction multiply-accumulate nodes (the code
// below distinguishes ISD::PARTIAL_REDUCE_UMLA from the other, presumably
// signed, opcode). Shown as a diff: '-' rows are the previous implementation,
// '+' rows the new one, unmarked rows are unchanged context.
//
// Op  - the node to lower; operand 0 is the accumulator, operands 1 and 2
//       are the two inputs.
// DAG - the SelectionDAG used to build replacement nodes.
// Returns a value of Op's original result type.
SDValue
29252
29286
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29253
29287
SelectionDAG &DAG) const {
29254
- bool Scalable = Op.getValueType().isScalableVector();
29255
-
29256
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29257
- "SVE or StreamingSVE must be available when using scalable vectors.");
29258
- assert((Scalable || Subtarget->hasDotProd()) &&
29259
- "Dotprod must be available when targeting NEON dot product "
29260
- "instructions.");
29261
-
29262
29288
SDLoc DL(Op);
29263
29289
29264
29290
// Operand 0 is the accumulator; operands 1 and 2 are the two inputs.
SDValue Acc = Op.getOperand(0);
29265
29291
SDValue LHS = Op.getOperand(1);
29266
29292
SDValue RHS = Op.getOperand(2);
29267
29293
EVT ResultVT = Op.getValueType();
29294
+ EVT OrigResultVT = ResultVT;
29295
+ EVT OpVT = LHS.getValueType();
29268
29296
29269
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29270
- LHS.getValueType() == MVT::nxv16i8) ||
29271
- (!Scalable && ResultVT == MVT::v2i64 &&
29272
- LHS.getValueType() == MVT::v16i8));
29297
// New in this change: a fixed-length result accepted by
// useSVEForFixedLengthVectorVT(..., OverrideNEON=true) is lowered on its
// scalable container type. OrigResultVT keeps the type to convert back to.
+ bool ConvertToScalable =
29298
+ ResultVT.isFixedLengthVector() &&
29299
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
29273
29300
29274
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29301
// Move the accumulator and both inputs to their container types and rebuild
// the node with the same opcode on those types.
+ if (ConvertToScalable) {
29302
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29303
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29304
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29305
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29306
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
29307
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29308
+ }
29309
+
29310
+ // Two-way and four-way partial reductions are supported by patterns.
29311
+ // We only need to handle the 8-way partial reduction.
29312
// Anything other than i8 inputs with an i64 result is returned unchanged
// (or as the rebuilt node converted back to OrigResultVT).
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29313
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29314
+ : Op;
29315
+
29316
// Step 1: accumulate LHS/RHS into an all-zero accumulator with 32-bit
// elements using the same opcode; step 2 below adds that into Acc.
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
29275
29317
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29276
29318
DAG.getConstant(0, DL, DotVT), LHS, RHS);
29277
29319
29320
+ SDValue Res;
29278
29321
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29279
- if (Scalable &&
29280
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29322
// The old condition also required a scalable result; the new one depends
// only on subtarget features.
// NOTE(review): if ConvertToScalable is false while ResultVT is fixed-length,
// the [US]ADDWB/[US]ADDWT nodes below are built with fixed-length types —
// confirm this combination cannot occur or is supported.
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29281
29323
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29282
29324
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29283
29325
// Add DotNode into Acc in two steps: first with the ...WB node, then with
// the ...WT node applied to that intermediate result.
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29284
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29285
- }
29286
-
29287
- // Fold (nx)v4i32 into (nx)v2i64
29288
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29289
- if (IsUnsigned) {
29290
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29291
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29326
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29292
29327
} else {
29293
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29294
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29328
+ // Fold (nx)v4i32 into (nx)v2i64
29329
// Split DotNode in two, zero- or sign-extend each half to ResultVT
// according to the opcode's signedness, then add both halves to Acc.
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29330
+ if (IsUnsigned) {
29331
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29332
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29333
+ } else {
29334
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29335
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29336
+ }
29337
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29338
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29295
29339
}
29296
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29297
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29340
+
29341
// Convert back to the original fixed-length type if the operands were moved
// to scalable containers above.
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29342
+ : Res;
29298
29343
}
29299
29344
29300
29345
SDValue
0 commit comments