@@ -1456,6 +1456,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1456
1456
// FADDP custom lowering
1457
1457
for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
1458
1458
setOperationAction(ISD::FADD, VT, Custom);
1459
+
1460
+ if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1461
+ setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1462
+ setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1463
+ setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1464
+ }
1465
+
1459
1466
} else /* !isNeonAvailable */ {
1460
1467
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1461
1468
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
@@ -29528,37 +29535,60 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29528
29535
}
29529
29536
29530
29537
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29531
- /// of nxv2i64/nxv16i8 , we cannot directly lower it to a (u|s)dot. We can
29538
+ /// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
29532
29539
/// however still make use of the dot product instruction by instead
29533
- /// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29540
+ /// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
29541
+ /// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
29542
+ /// the following pattern is emitted:
29543
+ /// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0))), ext(EXTRACT_SUBVECTOR(N,
29544
+ /// NTy/2)))
29534
29545
SDValue
29535
29546
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29536
29547
SelectionDAG &DAG) const {
29548
+ bool Scalable = Op.getValueType().isScalableVector();
29549
+
29550
+ assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29551
+ "SVE or StreamingSVE must be available when using scalable vectors.");
29552
+ assert((Scalable || Subtarget->hasDotProd()) &&
29553
+ "Dotprod must be available when targeting NEON dot product "
29554
+ "instructions.");
29555
+
29537
29556
SDLoc DL(Op);
29538
29557
29539
29558
SDValue Acc = Op.getOperand(0);
29540
29559
SDValue LHS = Op.getOperand(1);
29541
29560
SDValue RHS = Op.getOperand(2);
29542
29561
EVT ResultVT = Op.getValueType();
29543
- assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
29544
29562
29545
- SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
29546
- DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
29563
+ assert((Scalable && ResultVT == MVT::nxv2i64 &&
29564
+ LHS.getValueType() == MVT::nxv16i8) ||
29565
+ (!Scalable && ResultVT == MVT::v2i64 &&
29566
+ LHS.getValueType() == MVT::v16i8));
29567
+
29568
+ EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29569
+ SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29570
+ DAG.getConstant(0, DL, DotVT), LHS, RHS);
29547
29571
29548
29572
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29549
- if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29573
+ if (Scalable &&
29574
+ (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
29550
29575
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29551
29576
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29552
29577
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29553
29578
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29554
29579
}
29555
29580
29556
- unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29557
- unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29558
- auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29559
- auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29560
- auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29561
- return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29581
+ // Fold (nx)v4i32 into (nx)v2i64
29582
+ auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29583
+ if (IsUnsigned) {
29584
+ DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29585
+ DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29586
+ } else {
29587
+ DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29588
+ DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29589
+ }
29590
+ auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29591
+ return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
29562
29592
}
29563
29593
29564
29594
SDValue
0 commit comments