@@ -1930,6 +1930,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1930
1930
Custom);
1931
1931
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
1932
1932
Custom);
1933
+
1934
+ if (EnablePartialReduceNodes) {
1935
+ // Must be lowered to SVE instructions.
1936
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v4i32, Custom);
1937
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Custom);
1938
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1939
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v8i16, Custom);
1940
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Custom);
1941
+ setPartialReduceMLAAction(MVT::v8i16, MVT::v16i8, Custom);
1942
+ }
1933
1943
}
1934
1944
}
1935
1945
@@ -2225,6 +2235,26 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
2225
2235
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2226
2236
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2227
2237
2238
+ if (EnablePartialReduceNodes) {
2239
+ unsigned NumElts = VT.getVectorNumElements();
2240
+ if (VT.getVectorElementType() == MVT::i64) {
2241
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 8),
2242
+ Custom);
2243
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 4),
2244
+ Custom);
2245
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i32, NumElts * 2),
2246
+ Custom);
2247
+ } else if (VT.getVectorElementType() == MVT::i32) {
2248
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 4),
2249
+ Custom);
2250
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i16, NumElts * 2),
2251
+ Custom);
2252
+ } else if (VT.getVectorElementType() == MVT::i16) {
2253
+ setPartialReduceMLAAction(VT, MVT::getVectorVT(MVT::i8, NumElts * 2),
2254
+ Custom);
2255
+ }
2256
+ }
2257
+
2228
2258
// Lower fixed length vector operations to scalable equivalents.
2229
2259
setOperationAction(ISD::ABDS, VT, Default);
2230
2260
setOperationAction(ISD::ABDU, VT, Default);
@@ -29224,50 +29254,61 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29224
29254
// Custom lowering for a partial-reduce multiply-accumulate node.
//
// Operand 0 of Op is the accumulator (Acc); operands 1 and 2 are the two
// inputs (LHS/RHS). The value returned always has Op's original value type.
//
// Fixed-length results for which useSVEForFixedLengthVectorVT(ResultVT,
// /*OverrideNEON=*/true) holds are handled on their scalable container
// types and converted back to the original type just before returning.
//
// Only the combination "i64 result elements, i8 input elements" is expanded
// here. Any other combination is returned as-is (rebuilt on container types
// and converted back when ConvertToScalable is set).
//
// The expansion first emits a node with the same opcode that reduces LHS/RHS
// into a zero-initialised (nx)v4i32 value (DotNode), and then adds DotNode
// into Acc, either with the [US]ADDWB/[US]ADDWT nodes or by splitting DotNode
// in two, extending each half to ResultVT and emitting two ISD::ADD nodes.
SDValue
29225
29255
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29226
29256
SelectionDAG &DAG) const {
29227
- bool Scalable = Op.getValueType().isScalableVector();
29228
-
29229
- assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29230
- "SVE or StreamingSVE must be available when using scalable vectors.");
29231
- assert((Scalable || Subtarget->hasDotProd()) &&
29232
- "Dotprod must be available when targeting NEON dot product "
29233
- "instructions.");
29234
-
29235
29257
SDLoc DL(Op);
29236
29258
29237
29259
SDValue Acc = Op.getOperand(0);
29238
29260
SDValue LHS = Op.getOperand(1);
29239
29261
SDValue RHS = Op.getOperand(2);
29240
29262
EVT ResultVT = Op.getValueType();
// Remember the incoming result type: ResultVT may be replaced by a scalable
// container type below, and the final value is converted back to this type.
29263
+ EVT OrigResultVT = ResultVT;
29264
+ EVT OpVT = LHS.getValueType();
29241
29265
29242
- assert((Scalable && ResultVT == MVT::nxv2i64 &&
29243
- LHS.getValueType() == MVT::nxv16i8) ||
29244
- (!Scalable && ResultVT == MVT::v2i64 &&
29245
- LHS.getValueType() == MVT::v16i8));
// Decide whether a fixed-length node is processed on scalable types.
// NOTE(review): OverrideNEON is passed as true, presumably so that NEON-sized
// types also qualify -- confirm against useSVEForFixedLengthVectorVT.
29266
+ bool ConvertToScalable =
29267
+ ResultVT.isFixedLengthVector() &&
29268
+ useSVEForFixedLengthVectorVT(ResultVT, /*OverrideNEON=*/true);
29246
29269
29247
- EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29270
+ if (ConvertToScalable) {
29271
+ ResultVT = getContainerForFixedLengthVector(DAG, ResultVT);
29272
+ OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
29273
+ Acc = convertToScalableVector(DAG, ResultVT, Acc);
29274
+ LHS = convertToScalableVector(DAG, OpVT, LHS);
29275
+ RHS = convertToScalableVector(DAG, OpVT, RHS);
// Rebuild the node on the container types so that the early return below
// hands back the same operation expressed on the converted operands.
29276
+ Op = DAG.getNode(Op.getOpcode(), DL, ResultVT, {Acc, LHS, RHS});
29277
+ }
29278
+
29279
+ // Two-way and four-way partial reductions are supported by patterns.
29280
+ // We only need to handle the 8-way partial reduction.
29281
+ if (ResultVT.getScalarType() != MVT::i64 || OpVT.getScalarType() != MVT::i8)
29282
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Op)
29283
+ : Op;
29284
+
// 8-way case: reduce the i8 inputs with the same opcode into a zeroed
// accumulator with 32-bit elements; DotNode is then added into Acc below.
29285
+ EVT DotVT = ResultVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
29248
29286
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29249
29287
DAG.getConstant(0, DL, DotVT), LHS, RHS);
29250
29288
29289
+ SDValue Res;
29251
29290
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
// NOTE(review): unlike the removed `Scalable && ...` condition, this check
// does not test whether ResultVT is scalable. Presumably a fixed-length
// ResultVT cannot reach this point when SVE2 or streaming SVE is available,
// because it would have been converted above -- confirm.
29252
- if (Scalable &&
29253
- (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29291
+ if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
// Accumulate DotNode into Acc with the [US]ADDWB node followed by the
// [US]ADDWT node; the signed/unsigned variant follows the opcode of Op.
29254
29292
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29255
29293
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29256
29294
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29257
- return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29258
- }
29259
-
29260
- // Fold (nx)v4i32 into (nx)v2i64
29261
- auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29262
- if (IsUnsigned) {
29263
- DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29264
- DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29295
+ Res = DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29265
29296
} else {
29266
- DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29267
- DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29297
+ // Fold (nx)v4i32 into (nx)v2i64
29298
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29299
+ if (IsUnsigned) {
29300
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29301
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29302
+ } else {
29303
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29304
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29305
+ }
29306
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29307
+ Res = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29268
29308
}
29269
- auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29270
- return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29309
+
// Hand the result back in the caller's original (fixed-length) type when the
// computation was carried out on scalable container types.
29310
+ return ConvertToScalable ? convertFromScalableVector(DAG, OrigResultVT, Res)
29311
+ : Res;
29271
29312
}
29272
29313
29273
29314
SDValue
0 commit comments