Skip to content

Commit 26bae79

Browse files
[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes (#140075)
Lowering for fixed-width vectors is added to TableGen. There is also custom lowering to ensure that the USDOT patterns are still lowered for fixed-width vectors. It also ensures that the v16i8 -> v2i64 partial reduction case is lowered here instead of being split (as there is not a v2i64 dot product instruction). @JamesChesterman is the original author. --------- Co-authored-by: James Chesterman <[email protected]>
1 parent dc6aac5 commit 26bae79

File tree

4 files changed

+392
-70
lines changed

4 files changed

+392
-70
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14561456
// FADDP custom lowering
14571457
for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
14581458
setOperationAction(ISD::FADD, VT, Custom);
1459+
1460+
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1461+
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1462+
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1463+
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1464+
}
1465+
14591466
} else /* !isNeonAvailable */ {
14601467
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
14611468
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
@@ -29528,37 +29535,60 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2952829535
}
2952929536

2953029537
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29531-
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
29538+
/// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
2953229539
/// however still make use of the dot product instruction by instead
29533-
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29540+
/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
29541+
/// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
29542+
/// the following pattern is emitted:
29543+
/// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0))), ext(EXTRACT_SUBVECTOR(N,
29544+
/// NTy/2)))
2953429545
SDValue
2953529546
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2953629547
SelectionDAG &DAG) const {
29548+
bool Scalable = Op.getValueType().isScalableVector();
29549+
29550+
assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
29551+
"SVE or StreamingSVE must be available when using scalable vectors.");
29552+
assert((Scalable || Subtarget->hasDotProd()) &&
29553+
"Dotprod must be available when targeting NEON dot product "
29554+
"instructions.");
29555+
2953729556
SDLoc DL(Op);
2953829557

2953929558
SDValue Acc = Op.getOperand(0);
2954029559
SDValue LHS = Op.getOperand(1);
2954129560
SDValue RHS = Op.getOperand(2);
2954229561
EVT ResultVT = Op.getValueType();
29543-
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
2954429562

29545-
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
29546-
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
29563+
assert((Scalable && ResultVT == MVT::nxv2i64 &&
29564+
LHS.getValueType() == MVT::nxv16i8) ||
29565+
(!Scalable && ResultVT == MVT::v2i64 &&
29566+
LHS.getValueType() == MVT::v16i8));
29567+
29568+
EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
29569+
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
29570+
DAG.getConstant(0, DL, DotVT), LHS, RHS);
2954729571

2954829572
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29549-
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29573+
if (Scalable &&
29574+
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
2955029575
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
2955129576
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
2955229577
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
2955329578
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
2955429579
}
2955529580

29556-
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29557-
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29558-
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29559-
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29560-
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29561-
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29581+
// Fold (nx)v4i32 into (nx)v2i64
29582+
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
29583+
if (IsUnsigned) {
29584+
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
29585+
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
29586+
} else {
29587+
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
29588+
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
29589+
}
29590+
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
29591+
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
2956229592
}
2956329593

2956429594
SDValue

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
14741474
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
14751475
}
14761476

1477+
let Predicates = [HasNEON, HasDotProd] in {
1478+
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
1479+
(v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
1480+
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
1481+
(v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
1482+
def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
1483+
(v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
1484+
def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
1485+
(v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
1486+
} // End HasNEON, HasDotProd
1487+
14771488
// ARMv8.6-A BFloat
14781489
let Predicates = [HasNEON, HasBF16] in {
14791490
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;

0 commit comments

Comments
 (0)