Skip to content

Commit 3a93f70

Browse files
Dinar Temirbulatov
authored and committed
[AArch64][SVE] Improve code quality of vector unsigned add reduction.
For SVE we don't have to zero-extend and sum part of the result before issuing the UADDV instruction. This change also handles vector types wider than a legal vector type more efficiently, and lowers a fixed-length vector type to SVE's UADDV where appropriate.
1 parent 486d00e commit 3a93f70

File tree

4 files changed

+345
-28
lines changed

4 files changed

+345
-28
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17503,6 +17503,99 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1750317503
return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
1750417504
}
1750517505

17506+
// Turn vecreduce.add(zext(X)) into an SVE UADDV of X.  UADDV accumulates the
// unextended elements into a 64-bit scalar, so the explicit zero-extend (and
// the partial sums it would otherwise require) is unnecessary.  Three cases
// are handled:
//   * X has a legal scalable type: emit a single UADDV.
//   * X has a fixed-length type lowered via SVE: wrap it in an SVE container
//     and emit a single UADDV.
//   * X is wider than any legal vector type: split it in halves until every
//     piece is legal, UADDV each piece, then ADD the scalar partial sums.
// Returns the replacement value, or an empty SDValue if no combine applies.
static SDValue
performVecReduceAddZextCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               const AArch64TargetLowering &TLI) {
  // Only zero-extended inputs map onto UADDV's unsigned widening sum.
  if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND)
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
  SDLoc DL(N);

  SDValue VecOp = N->getOperand(0).getOperand(0);
  EVT VecVT = VecOp.getValueType();
  bool IsScalableType = VecVT.isScalableVector();

  if (TLI.isTypeLegal(VecVT)) {
    // Fixed-length vectors are only profitable here when they are lowered
    // through SVE registers.
    if (!IsScalableType &&
        !TLI.useSVEForFixedLengthVectorVT(
            VecVT,
            /*OverrideNEON=*/Subtarget.useSVEForFixedLengthVectors(VecVT)))
      return SDValue();

    if (!IsScalableType) {
      EVT ContainerVT = getContainerForFixedLengthVector(DAG, VecVT);
      VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
      VecVT = VecOp.getValueType();
    }

    SDValue Pg = getPredicateForVector(DAG, DL, VecVT);
    SDValue Res = DAG.getNode(
        ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
        DAG.getConstant(Intrinsic::aarch64_sve_uaddv, DL, MVT::i64), Pg, VecOp);
    // UADDV always produces an i64; narrow/widen to the reduction's type.
    if (N->getValueType(0) != MVT::i64)
      Res = DAG.getAnyExtOrTrunc(Res, DL, N->getValueType(0));
    return Res;
  }

  // The input is wider than any legal vector type: halve every piece until
  // all pieces have a legal type.
  SmallVector<SDValue, 4> SplitVals;
  SmallVector<SDValue, 4> PrevVals;
  PrevVals.push_back(VecOp);
  while (true) {
    // A fixed-length vector with an odd element count cannot be split evenly.
    if (!VecVT.isScalableVector() &&
        !PrevVals[0].getValueType().getVectorElementCount().isKnownEven())
      return SDValue();

    for (SDValue Vec : PrevVals) {
      SDValue Lo, Hi;
      std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
      SplitVals.push_back(Lo);
      SplitVals.push_back(Hi);
    }
    if (TLI.isTypeLegal(SplitVals[0].getValueType()))
      break;
    PrevVals = std::move(SplitVals);
    SplitVals.clear();
  }

  // As above, bail out on fixed-length pieces that would not use SVE.
  if (!IsScalableType &&
      !TLI.useSVEForFixedLengthVectorVT(
          SplitVals[0].getValueType(),
          /*OverrideNEON=*/Subtarget.useSVEForFixedLengthVectors(
              SplitVals[0].getValueType())))
    return SDValue();

  EVT ElemType = N->getValueType(0);

  // Reduce each legal piece with its own UADDV.
  SmallVector<SDValue, 4> Results;
  for (SDValue Reg : SplitVals) {
    EVT RdxVT = Reg.getValueType();
    SDValue Pg = getPredicateForVector(DAG, DL, RdxVT);
    if (!IsScalableType) {
      EVT ContainerVT = getContainerForFixedLengthVector(DAG, RdxVT);
      Reg = convertToScalableVector(DAG, ContainerVT, Reg);
    }
    SDValue Res = DAG.getNode(
        ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
        DAG.getConstant(Intrinsic::aarch64_sve_uaddv, DL, MVT::i64), Pg, Reg);
    if (ElemType != MVT::i64)
      Res = DAG.getAnyExtOrTrunc(Res, DL, ElemType);
    Results.push_back(Res);
  }

  // Sum the scalar partial reductions.
  SDValue ToAdd = Results[0];
  for (unsigned I = 1, E = Results.size(); I < E; ++I)
    ToAdd = DAG.getNode(ISD::ADD, DL, ElemType, ToAdd, Results[I]);
  return ToAdd;
}
17598+
1750617599
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
1750717600
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
1750817601
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
@@ -25188,8 +25281,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2518825281
return performInsertVectorEltCombine(N, DCI);
2518925282
case ISD::EXTRACT_VECTOR_ELT:
2519025283
return performExtractVectorEltCombine(N, DCI, Subtarget);
25191-
case ISD::VECREDUCE_ADD:
25192-
return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
25284+
case ISD::VECREDUCE_ADD: {
25285+
if (SDValue Val = performVecReduceAddCombine(N, DCI.DAG, Subtarget))
25286+
return Val;
25287+
return performVecReduceAddZextCombine(N, DCI, *this);
25288+
}
2519325289
case AArch64ISD::UADDV:
2519425290
return performUADDVCombine(N, DAG);
2519525291
case AArch64ISD::SMULL:

llvm/test/CodeGen/AArch64/sve-doublereduct.ll

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,12 @@ define i32 @add_i32(<vscale x 8 x i32> %a, <vscale x 4 x i32> %b) {
103103
define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
104104
; CHECK-LABEL: add_ext_i16:
105105
; CHECK: // %bb.0:
106-
; CHECK-NEXT: uunpkhi z2.h, z0.b
107-
; CHECK-NEXT: uunpklo z0.h, z0.b
108-
; CHECK-NEXT: uunpkhi z3.h, z1.b
109-
; CHECK-NEXT: uunpklo z1.h, z1.b
110-
; CHECK-NEXT: ptrue p0.h
111-
; CHECK-NEXT: add z0.h, z0.h, z2.h
112-
; CHECK-NEXT: add z1.h, z1.h, z3.h
113-
; CHECK-NEXT: add z0.h, z0.h, z1.h
114-
; CHECK-NEXT: uaddv d0, p0, z0.h
115-
; CHECK-NEXT: fmov x0, d0
116-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
106+
; CHECK-NEXT: ptrue p0.b
107+
; CHECK-NEXT: uaddv d0, p0, z0.b
108+
; CHECK-NEXT: uaddv d1, p0, z1.b
109+
; CHECK-NEXT: fmov w8, s0
110+
; CHECK-NEXT: fmov w9, s1
111+
; CHECK-NEXT: add w0, w8, w9
117112
; CHECK-NEXT: ret
118113
%ae = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
119114
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
@@ -126,21 +121,15 @@ define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
126121
define i16 @add_ext_v32i16(<vscale x 32 x i8> %a, <vscale x 16 x i8> %b) {
127122
; CHECK-LABEL: add_ext_v32i16:
128123
; CHECK: // %bb.0:
129-
; CHECK-NEXT: uunpklo z3.h, z1.b
130-
; CHECK-NEXT: uunpklo z4.h, z0.b
131-
; CHECK-NEXT: uunpkhi z1.h, z1.b
132-
; CHECK-NEXT: uunpkhi z0.h, z0.b
133-
; CHECK-NEXT: uunpkhi z5.h, z2.b
134-
; CHECK-NEXT: uunpklo z2.h, z2.b
135-
; CHECK-NEXT: ptrue p0.h
136-
; CHECK-NEXT: add z0.h, z0.h, z1.h
137-
; CHECK-NEXT: add z1.h, z4.h, z3.h
138-
; CHECK-NEXT: add z0.h, z1.h, z0.h
139-
; CHECK-NEXT: add z1.h, z2.h, z5.h
140-
; CHECK-NEXT: add z0.h, z0.h, z1.h
141-
; CHECK-NEXT: uaddv d0, p0, z0.h
142-
; CHECK-NEXT: fmov x0, d0
143-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
124+
; CHECK-NEXT: ptrue p0.b
125+
; CHECK-NEXT: uaddv d1, p0, z1.b
126+
; CHECK-NEXT: uaddv d0, p0, z0.b
127+
; CHECK-NEXT: uaddv d2, p0, z2.b
128+
; CHECK-NEXT: fmov w8, s1
129+
; CHECK-NEXT: fmov w9, s0
130+
; CHECK-NEXT: add w8, w9, w8
131+
; CHECK-NEXT: fmov w9, s2
132+
; CHECK-NEXT: add w0, w8, w9
144133
; CHECK-NEXT: ret
145134
%ae = zext <vscale x 32 x i8> %a to <vscale x 32 x i16>
146135
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>

llvm/test/CodeGen/AArch64/sve-int-reduce.ll

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,103 @@ define i64 @uaddv_nxv2i64(<vscale x 2 x i64> %a) {
188188
ret i64 %res
189189
}
190190

191+
define i32 @uaddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
192+
; CHECK-LABEL: uaddv_nxv16i8_nxv16i32:
193+
; CHECK: // %bb.0:
194+
; CHECK-NEXT: ptrue p0.b
195+
; CHECK-NEXT: uaddv d0, p0, z0.b
196+
; CHECK-NEXT: fmov x0, d0
197+
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
198+
; CHECK-NEXT: ret
199+
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
200+
%2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
201+
ret i32 %2
202+
}
203+
204+
define i64 @uaddv_nxv16i16_nxv16i64(<vscale x 16 x i16> %a) {
205+
; CHECK-LABEL: uaddv_nxv16i16_nxv16i64:
206+
; CHECK: // %bb.0:
207+
; CHECK-NEXT: ptrue p0.h
208+
; CHECK-NEXT: uaddv d1, p0, z1.h
209+
; CHECK-NEXT: uaddv d0, p0, z0.h
210+
; CHECK-NEXT: fmov x8, d1
211+
; CHECK-NEXT: fmov x9, d0
212+
; CHECK-NEXT: add x0, x9, x8
213+
; CHECK-NEXT: ret
214+
%1 = zext <vscale x 16 x i16> %a to <vscale x 16 x i64>
215+
%2 = call i64 @llvm.vector.reduce.add.nxv16i64(<vscale x 16 x i64> %1)
216+
ret i64 %2
217+
}
218+
219+
define i32 @uaddv_nxv16i16_nxv16i32(<vscale x 32 x i16> %a) {
; CHECK-LABEL: uaddv_nxv16i16_nxv16i32:
; CHECK:       // %bb.0:
; CHECK-NEXT:    ptrue p0.h
; CHECK-NEXT:    uaddv d1, p0, z1.h
; CHECK-NEXT:    uaddv d0, p0, z0.h
; CHECK-NEXT:    uaddv d2, p0, z2.h
; CHECK-NEXT:    uaddv d3, p0, z3.h
; CHECK-NEXT:    fmov w8, s1
; CHECK-NEXT:    fmov w9, s0
; CHECK-NEXT:    add w8, w9, w8
; CHECK-NEXT:    fmov w9, s2
; CHECK-NEXT:    add w8, w8, w9
; CHECK-NEXT:    fmov w9, s3
; CHECK-NEXT:    add w0, w8, w9
; CHECK-NEXT:    ret
  %1 = zext <vscale x 32 x i16> %a to <vscale x 32 x i32>
  ; Intrinsic mangling must match the operand type: nxv32i32, not nxv32i64.
  %2 = call i32 @llvm.vector.reduce.add.nxv32i32(<vscale x 32 x i32> %1)
  ret i32 %2
}
239+
240+
define i32 @saddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
241+
; CHECK-LABEL: saddv_nxv16i8_nxv16i32:
242+
; CHECK: // %bb.0:
243+
; CHECK-NEXT: sunpkhi z1.h, z0.b
244+
; CHECK-NEXT: sunpklo z0.h, z0.b
245+
; CHECK-NEXT: ptrue p0.s
246+
; CHECK-NEXT: sunpklo z2.s, z1.h
247+
; CHECK-NEXT: sunpklo z3.s, z0.h
248+
; CHECK-NEXT: sunpkhi z1.s, z1.h
249+
; CHECK-NEXT: sunpkhi z0.s, z0.h
250+
; CHECK-NEXT: add z0.s, z0.s, z1.s
251+
; CHECK-NEXT: add z1.s, z3.s, z2.s
252+
; CHECK-NEXT: add z0.s, z1.s, z0.s
253+
; CHECK-NEXT: uaddv d0, p0, z0.s
254+
; CHECK-NEXT: fmov x0, d0
255+
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
256+
; CHECK-NEXT: ret
257+
%1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
258+
%2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
259+
ret i32 %2
260+
}
261+
262+
define i32 @uaddv_nxv32i16_nxv32i32(ptr %a) {
263+
; CHECK-LABEL: uaddv_nxv32i16_nxv32i32:
264+
; CHECK: // %bb.0:
265+
; CHECK-NEXT: ptrue p0.h
266+
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0, #1, mul vl]
267+
; CHECK-NEXT: ld1h { z1.h }, p0/z, [x0]
268+
; CHECK-NEXT: ld1h { z2.h }, p0/z, [x0, #2, mul vl]
269+
; CHECK-NEXT: ld1h { z3.h }, p0/z, [x0, #3, mul vl]
270+
; CHECK-NEXT: uaddv d0, p0, z0.h
271+
; CHECK-NEXT: uaddv d1, p0, z1.h
272+
; CHECK-NEXT: uaddv d2, p0, z2.h
273+
; CHECK-NEXT: uaddv d3, p0, z3.h
274+
; CHECK-NEXT: fmov w8, s0
275+
; CHECK-NEXT: fmov w9, s1
276+
; CHECK-NEXT: add w8, w9, w8
277+
; CHECK-NEXT: fmov w9, s2
278+
; CHECK-NEXT: add w8, w8, w9
279+
; CHECK-NEXT: fmov w9, s3
280+
; CHECK-NEXT: add w0, w8, w9
281+
; CHECK-NEXT: ret
282+
%1 = load <vscale x 32 x i16>, ptr %a, align 16
283+
%2 = zext <vscale x 32 x i16> %1 to <vscale x 32 x i32>
284+
%3 = call i32 @llvm.vector.reduce.add.nxv32i32(<vscale x 32 x i32> %2)
285+
ret i32 %3
286+
}
287+
191288
; UMINV
192289

193290
define i8 @umin_nxv16i8(<vscale x 16 x i8> %a) {

0 commit comments

Comments
 (0)