@@ -1930,6 +1930,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1930
1930
Custom);
1931
1931
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
1932
1932
Custom);
1933
+
1934
+ // Must be lowered to SVE instructions.
1935
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
1936
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
1937
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1938
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
1939
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
1940
+ setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
1933
1941
}
1934
1942
}
1935
1943
@@ -2225,6 +2233,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
2225
2233
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2226
2234
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2227
2235
2236
+ if (EnablePartialReduceNodes) {
2237
+ unsigned NumElts = VT.getVectorNumElements();
2238
+ if (VT.getVectorElementType() == MVT::i64) {
2239
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
2240
+ Custom);
2241
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
2242
+ Custom);
2243
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
2244
+ Custom);
2245
+ } else if (VT.getVectorElementType() == MVT::i32) {
2246
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
2247
+ Custom);
2248
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
2249
+ Custom);
2250
+ } else if (VT.getVectorElementType() == MVT::i16) {
2251
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
2252
+ Custom);
2253
+ }
2254
+ }
2255
+
2228
2256
// Lower fixed length vector operations to scalable equivalents.
2229
2257
setOperationAction(ISD::ABDS, VT, Default);
2230
2258
setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29252,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29224
29252
SDValue
29225
29253
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29226
29254
SelectionDAG &DAG) const {
29227
- bool Scalable = Op.getValueType().isScalableVector();
29228
-
29229
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29230
- "SVE or StreamingSVE must be available when using scalable vectors.");
29231
- assert((Scalable || Subtarget->hasDotProd()) &&
29232
- "Dotprod must be available when targeting NEON dot product "
29233
- "instructions.");
29234
-
29235
29255
SDLoc DL(Op);
29236
29256
29237
29257
SDValue Acc = Op.getOperand(0);
29238
29258
SDValue LHS = Op.getOperand(1);
29239
29259
SDValue RHS = Op.getOperand(2);
29240
29260
EVT ResultVT = Op.getValueType();
29261
+ EVT OrigResultVT = ResultVT;
29262
+ EVT OpVT = LHS.getValueType();
29241
29263
29242
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29243
- LHS.getValueType() == MVT::nxv16i8) ||
29244
- (!Scalable && ResultVT == MVT::v2i64 &&
29245
- LHS.getValueType() == MVT::v16i8));
29264
+ bool ConvertToScalable =
29265
+ ResultVT.isFixedLengthVector() &&
29266
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
29246
29267
29247
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29268
+ if (ConvertToScalable) {
29269
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29270
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29271
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29272
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29273
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
29274
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29275
+ }
29276
+
29277
+ // Two-way and four-way partial reductions are supported by patterns.
29278
+ // We only need to handle the 8-way partial reduction.
29279
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29280
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29281
+ : Op;
29282
+
29283
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
29248
29284
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29249
29285
DAG.getConstant(0, DL, DotVT), LHS, RHS);
29250
29286
29287
+ SDValue Res;
29251
29288
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29252
- if (Scalable &&
29253
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29289
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29254
29290
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29255
29291
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29256
29292
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29257
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29258
- }
29259
-
29260
- // Fold (nx)v4i32 into (nx)v2i64
29261
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29262
- if (IsUnsigned) {
29263
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29264
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29293
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29265
29294
} else {
29266
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29267
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29295
+ // Fold (nx)v4i32 into (nx)v2i64
29296
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29297
+ if (IsUnsigned) {
29298
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29299
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29300
+ } else {
29301
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29302
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29303
+ }
29304
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29305
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29268
29306
}
29269
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29270
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29307
+
29308
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29309
+ : Res;
29271
29310
}
29272
29311
29273
29312
SDValue
0 commit comments