
Commit c3c2e1e

[AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions (llvm#114406)
For partial reductions in which the number of elements is halved, a pair of wide add instructions can be used.
1 parent e05d91b commit c3c2e1e
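A minimal sketch of the rewrite for the signed nxv4i32 -> nxv2i64 case, written as roughly equivalent IR. This is only illustrative: the combine builds the corresponding ISD::INTRINSIC_WO_CHAIN nodes directly on the SelectionDAG rather than emitting these calls, and the @llvm.aarch64.sve.saddwb/.saddwt intrinsic names and signatures are assumed here for the sketch.

  ; Source pattern: a partial reduction of a sign-extended half-width input.
  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)

  ; Lowered shape: saddwb adds the widened even-numbered ("bottom") i32 elements
  ; into the i64 accumulator, saddwt adds the odd-numbered ("top") ones, so every
  ; input element is accumulated exactly once, which is all the partial reduction
  ; intrinsic requires.
  %bottom = call <vscale x 2 x i64> @llvm.aarch64.sve.saddwb.nxv2i64(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input)
  %result = call <vscale x 2 x i64> @llvm.aarch64.sve.saddwt.nxv2i64(<vscale x 2 x i64> %bottom, <vscale x 4 x i32> %input)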

File tree

2 files changed: +199 -2 lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 58 additions & 2 deletions
@@ -2039,8 +2039,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+  auto Op1 = I->getOperand(1);
+  EVT Op1VT = EVT::getEVT(Op1->getType());
+  if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+      (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
+       VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
+    return false;
+  return true;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
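Returning false here tells the generic lowering not to expand the intrinsic, so it survives to the AArch64 DAG combine below. The hard-coded list of supported result types is replaced by a more general check: the (already extended) operand must have the same element type as the accumulator and either four times as many elements (the existing dot-product path) or twice as many (the new wide-add path).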
@@ -21784,6 +21789,55 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+                                          const AArch64Subtarget *Subtarget,
+                                          SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  auto Acc = N->getOperand(1);
+  auto ExtInput = N->getOperand(2);
+
+  EVT AccVT = Acc.getValueType();
+  EVT AccElemVT = AccVT.getVectorElementType();
+
+  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+    return SDValue();
+
+  unsigned ExtInputOpcode = ExtInput->getOpcode();
+  if (!ISD::isExtOpcode(ExtInputOpcode))
+    return SDValue();
+
+  auto Input = ExtInput->getOperand(0);
+  EVT InputVT = Input.getValueType();
+
+  if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+    return SDValue();
+
+  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+                                       : Intrinsic::aarch64_sve_uaddwb;
+  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+                                    : Intrinsic::aarch64_sve_uaddwt;
+
+  auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
+  auto BottomNode =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
+  auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
+                     Input);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21795,6 +21849,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
      return Dot;
+    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+      return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));
   }
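The new test file below covers the three scalable-vector combinations the lowering accepts (nxv4i32 into nxv2i64, nxv8i16 into nxv4i32, nxv16i8 into nxv8i16) for both sign- and zero-extended inputs, and checks that unsupported combinations (nxv4i16 into nxv2i32, and the wider nxv8i32 into nxv4i64 case) still go through the generic unpack-and-add expansion.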
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s

define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-LABEL: signed_wide_add_nxv4i32:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    saddwb z0.d, z0.d, z1.s
; CHECK-NEXT:    saddwt z0.d, z0.d, z1.s
; CHECK-NEXT:    ret
entry:
  %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
  ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-LABEL: unsigned_wide_add_nxv4i32:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    uaddwb z0.d, z0.d, z1.s
; CHECK-NEXT:    uaddwt z0.d, z0.d, z1.s
; CHECK-NEXT:    ret
entry:
  %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
  ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
; CHECK-LABEL: signed_wide_add_nxv8i16:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    saddwb z0.s, z0.s, z1.h
; CHECK-NEXT:    saddwt z0.s, z0.s, z1.h
; CHECK-NEXT:    ret
entry:
  %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
  ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
; CHECK-LABEL: unsigned_wide_add_nxv8i16:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    uaddwb z0.s, z0.s, z1.h
; CHECK-NEXT:    uaddwt z0.s, z0.s, z1.h
; CHECK-NEXT:    ret
entry:
  %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
  ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
; CHECK-LABEL: signed_wide_add_nxv16i8:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    saddwb z0.h, z0.h, z1.b
; CHECK-NEXT:    saddwt z0.h, z0.h, z1.b
; CHECK-NEXT:    ret
entry:
  %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
  ret <vscale x 8 x i16> %partial.reduce
}

define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
; CHECK-LABEL: unsigned_wide_add_nxv16i8:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    uaddwb z0.h, z0.h, z1.b
; CHECK-NEXT:    uaddwt z0.h, z0.h, z1.b
; CHECK-NEXT:    ret
entry:
  %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
  %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
  ret <vscale x 8 x i16> %partial.reduce
}

define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
; CHECK-LABEL: signed_wide_add_nxv4i16:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    ptrue p0.s
; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
; CHECK-NEXT:    uunpklo z2.d, z1.s
; CHECK-NEXT:    uunpkhi z1.d, z1.s
; CHECK-NEXT:    add z0.d, z0.d, z2.d
; CHECK-NEXT:    add z0.d, z1.d, z0.d
; CHECK-NEXT:    ret
entry:
  %input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
  %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
  ret <vscale x 2 x i32> %partial.reduce
}

define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
; CHECK-LABEL: unsigned_wide_add_nxv4i16:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    and z1.s, z1.s, #0xffff
; CHECK-NEXT:    uunpklo z2.d, z1.s
; CHECK-NEXT:    uunpkhi z1.d, z1.s
; CHECK-NEXT:    add z0.d, z0.d, z2.d
; CHECK-NEXT:    add z0.d, z1.d, z0.d
; CHECK-NEXT:    ret
entry:
  %input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
  %partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
  ret <vscale x 2 x i32> %partial.reduce
}

define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
; CHECK-LABEL: signed_wide_add_nxv8i32:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    sunpkhi z4.d, z2.s
; CHECK-NEXT:    sunpklo z2.d, z2.s
; CHECK-NEXT:    sunpkhi z5.d, z3.s
; CHECK-NEXT:    sunpklo z3.d, z3.s
; CHECK-NEXT:    add z0.d, z0.d, z2.d
; CHECK-NEXT:    add z1.d, z1.d, z4.d
; CHECK-NEXT:    add z0.d, z3.d, z0.d
; CHECK-NEXT:    add z1.d, z5.d, z1.d
; CHECK-NEXT:    ret
entry:
  %input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
  ret <vscale x 4 x i64> %partial.reduce
}

define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
; CHECK-LABEL: unsigned_wide_add_nxv8i32:
; CHECK:       // %bb.0: // %entry
; CHECK-NEXT:    uunpkhi z4.d, z2.s
; CHECK-NEXT:    uunpklo z2.d, z2.s
; CHECK-NEXT:    uunpkhi z5.d, z3.s
; CHECK-NEXT:    uunpklo z3.d, z3.s
; CHECK-NEXT:    add z0.d, z0.d, z2.d
; CHECK-NEXT:    add z1.d, z1.d, z4.d
; CHECK-NEXT:    add z0.d, z3.d, z0.d
; CHECK-NEXT:    add z1.d, z5.d, z1.d
; CHECK-NEXT:    ret
entry:
  %input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
  ret <vscale x 4 x i64> %partial.reduce
}
