@@ -1935,6 +1935,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1935
1935
Custom);
1936
1936
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
1937
1937
Custom);
1938
+
1939
+ if (EnablePartialReduceNodes) {
1940
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1941
+ ISD::PARTIAL_REDUCE_UMLA};
1942
+ // Must be lowered to SVE instructions.
1943
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1944
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1945
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1946
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1947
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1948
+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1949
+ }
1938
1950
}
1939
1951
}
1940
1952
@@ -2230,6 +2242,28 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
2230
2242
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2231
2243
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2232
2244
2245
+ if (EnablePartialReduceNodes) {
2246
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2247
+ ISD::PARTIAL_REDUCE_UMLA};
2248
+ unsigned NumElts = VT.getVectorNumElements();
2249
+ if (VT.getVectorElementType() == MVT::i64) {
2250
+ setPartialReduceMLAAction(MLAOps, VT,
2251
+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2252
+ setPartialReduceMLAAction(
2253
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2254
+ setPartialReduceMLAAction(
2255
+ MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2256
+ } else if (VT.getVectorElementType() == MVT::i32) {
2257
+ setPartialReduceMLAAction(MLAOps, VT,
2258
+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2259
+ setPartialReduceMLAAction(
2260
+ MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2261
+ } else if (VT.getVectorElementType() == MVT::i16) {
2262
+ setPartialReduceMLAAction(MLAOps, VT,
2263
+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2264
+ }
2265
+ }
2266
+
2233
2267
// Lower fixed length vector operations to scalable equivalents.
2234
2268
setOperationAction(ISD::ABDS, VT, Default);
2235
2269
setOperationAction(ISD::ABDU, VT, Default);
@@ -29229,50 +29263,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29229
29263
// Custom lowering of a partial-reduce multiply-accumulate node whose operands
// are (accumulator, LHS, RHS). The opcode ISD::PARTIAL_REDUCE_UMLA selects the
// unsigned forms below; any other opcode takes the signed forms.
//
// NOTE(review): this span is a rendered diff. Lines starting with '-' are the
// pre-change code, lines starting with '+' are the post-change code, unprefixed
// code lines are shared by both, and bare numbers are old/new line-number
// residue. The added comments describe the post-change ('+') behaviour unless
// stated otherwise.
SDValue
29230
29264
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29231
29265
SelectionDAG &DAG) const {
29232
// Pre-change only: the scalable/fixed distinction and its subtarget asserts
// are removed by this change.
- bool Scalable = Op.getValueType().isScalableVector();
29233
-
29234
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29235
- "SVE or StreamingSVE must be available when using scalable vectors.");
29236
- assert((Scalable || Subtarget->hasDotProd()) &&
29237
- "Dotprod must be available when targeting NEON dot product "
29238
- "instructions.");
29239
-
29240
29266
SDLoc DL(Op);
29241
29267
29242
29268
SDValue Acc = Op.getOperand(0);
29243
29269
SDValue LHS = Op.getOperand(1);
29244
29270
SDValue RHS = Op.getOperand(2);
29245
29271
EVT ResultVT = Op.getValueType();
29272
// Remember the incoming result type so the final value can be converted back
// if ResultVT is replaced by a container type below.
+ EVT OrigResultVT = ResultVT;
29273
+ EVT OpVT = LHS.getValueType();
29246
29274
29247
// Pre-change only: the old code accepted exactly nxv2i64 <- nxv16i8 or
// v2i64 <- v16i8; this assert is removed.
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29248
- LHS.getValueType() == MVT::nxv16i8) ||
29249
- (!Scalable && ResultVT == MVT::v2i64 &&
29250
- LHS.getValueType() == MVT::v16i8));
29275
// A fixed-length result that useSVEForFixedLengthVectorVT accepts (queried
// with OverrideNEON set) is lowered through its container vector type.
+ bool ConvertToScalable =
29276
+ ResultVT.isFixedLengthVector() &&
29277
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
29251
29278
29252
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29279
// Move all three operands into their container types and rebuild the node
// with the same opcode on those types; Op now refers to the rebuilt node.
+ if (ConvertToScalable) {
29280
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29281
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29282
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29283
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29284
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
29285
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29286
+ }
29287
+
29288
+ // Two-way and four-way partial reductions are supported by patterns.
29289
+ // We only need to handle the 8-way partial reduction.
29290
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29291
// Anything other than i64 <- i8 is returned as-is: the rebuilt node converted
// back to the original fixed-length type, or the untouched incoming Op.
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29292
+ : Op;
29293
+
29294
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
29253
29295
// First step of the i64 <- i8 case: the same opcode applied to LHS/RHS with a
// zero accumulator of the intermediate (nx)v4i32 type.
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29254
29296
DAG.getConstant(0, DL, DotVT), LHS, RHS);
29255
29297
29298
+ SDValue Res;
29256
29299
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29257
- if (Scalable &&
29258
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29300
// Post-change the subtarget check is no longer guarded by "Scalable".
// NOTE(review): presumably a fixed-length ResultVT cannot reach this branch
// with SVE2 available because ConvertToScalable would be set — confirm.
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29259
29301
// Accumulate DotNode into Acc with the *ADDWB node followed by the *ADDWT
// node; both take DotNode as their second operand.
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29260
29302
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29261
29303
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29262
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29263
- }
29264
-
29265
- // Fold (nx)v4i32 into (nx)v2i64
29266
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29267
- if (IsUnsigned) {
29268
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29269
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29304
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29270
29305
} else {
29271
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29272
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29306
// Fallback: split DotNode into two halves, extend each half to ResultVT
// (zero-extend when unsigned, sign-extend otherwise) and add both to Acc.
+ // Fold (nx)v4i32 into (nx)v2i64
29307
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29308
+ if (IsUnsigned) {
29309
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29310
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29311
+ } else {
29312
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29313
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29314
+ }
29315
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29316
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29273
29317
}
29274
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29275
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29318
+
29319
// Single exit: convert the result back to the original fixed-length type
// when the operands were moved into container types above.
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29320
+ : Res;
29276
29321
}
29277
29322
29278
29323
SDValue
0 commit comments