
Commit 198f6b9

SamTebbs33 authored and fhahn committed
[AArch64] Lower partial add reduction to udot or svdot (llvm#101010)
This patch introduces lowering of the partial add reduction intrinsic to a udot or svdot for AArch64. It also adds a `shouldExpandPartialReductionIntrinsic` target hook, from which AArch64 returns false in the cases it can lower.
1 parent 3874ca2 commit 198f6b9
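As a quick illustration of the change (lifted directly from the new test added below), the following IR pattern, a zero-extended multiply feeding the partial-reduce intrinsic, now selects to a single SVE udot when run through `llc -mtriple=aarch64 -mattr=+sve2`:

define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  ; Widen both inputs, multiply, then accumulate the nxv16i32 product into the
  ; nxv4i32 accumulator; with this patch AArch64 emits: udot z0.s, z1.b, z2.b
  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}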

File tree

7 files changed: +217, -25 lines


llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 5 additions & 0 deletions
@@ -1587,6 +1587,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+  /// its operands and ReducedTy is the intrinsic's return type.
+  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                              SDValue Op2);
+
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
 

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 7 additions & 0 deletions
@@ -454,6 +454,13 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 30 additions & 0 deletions
@@ -74,6 +74,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdlib>
+#include <deque>
 #include <limits>
 #include <optional>
 #include <set>
@@ -2411,6 +2412,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2) {
+  EVT FullTy = Op2.getValueType();
+
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
+
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+
+  return Subvectors[0];
+}
+
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
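For intuition, here is a hedged IR-level sketch of what this generic expansion computes when ReducedTy is <vscale x 4 x i32> and Op2 has type <vscale x 16 x i32>: Stride is 4, ScaleFactor is 4, so four subvectors are extracted at element offsets 0, 4, 8 and 12 and summed with the accumulator as a balanced tree. The helper itself builds ISD::EXTRACT_SUBVECTOR and ISD::ADD nodes; the llvm.vector.extract form below is only an illustrative equivalent, not code from this patch.

define <vscale x 4 x i32> @partial_reduce_expansion_sketch(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %wide) {
  ; ScaleFactor = 16 / 4 = 4 subvectors, extracted at strides of 4 elements.
  %s0 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 0)
  %s1 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 4)
  %s2 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 8)
  %s3 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %wide, i64 12)
  ; The deque loop repeatedly adds the front two entries of {acc, s0, s1, s2, s3}
  ; and appends the sum until one value remains.
  %t0 = add <vscale x 4 x i32> %acc, %s0
  %t1 = add <vscale x 4 x i32> %s1, %s2
  %t2 = add <vscale x 4 x i32> %s3, %t0
  %t3 = add <vscale x 4 x i32> %t1, %t2
  ret <vscale x 4 x i32> %t3
}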

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 6 additions & 25 deletions
@@ -8000,34 +8000,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    SDValue OpNode = getValue(I.getOperand(1));
-    EVT ReducedTy = EVT::getEVT(I.getType());
-    EVT FullTy = OpNode.getValueType();
 
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors;
-    Subvectors.push_back(getValue(I.getOperand(0)));
-    for (unsigned i = 0; i < ScaleFactor; i++) {
-      auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
-      Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
-                                       {OpNode, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
     }
 
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
-
-    setValue(&I, Subvectors[0]);
+    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+                                         getValue(I.getOperand(0)),
+                                         getValue(I.getOperand(1))));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 70 additions & 0 deletions
@@ -1971,6 +1971,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const IntrinsicInst *I) const {
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+    return true;
+
+  EVT VT = EVT::getEVT(I->getType());
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21250,6 +21259,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+
+  if (IsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (IsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  EVT ReducedType = N->getValueType(0);
+  EVT MulSrcType = A.getValueType();
+
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
+    return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
+
+  if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
+    return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+
+  return SDValue();
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21258,6 +21322,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                   N->getOperand(1), N->getOperand(2));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
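Note that the combine only fires when both mul operands use the same kind of extend and the types form one of the nxv4i32/nxv16i8 or nxv2i64/nxv8i16 pairs; anything else falls through to DAG.getPartialReduceAdd. A hypothetical example (not one of the added tests) that is rejected because the two extends differ in signedness:

define <vscale x 4 x i32> @not_dotp_mixed_ext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  ; sext on one operand and zext on the other: ExtA and ExtB opcodes differ, so
  ; tryLowerPartialReductionToDot returns SDValue() and the generic expansion is used.
  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}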

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
@@ -991,6 +991,9 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %accc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
