Skip to content

Commit ffa5bb1

Browse files
committed
[AArch64][SVE] Fold zero-extend into add reduction.
The original pull request #97339 from @dtemirbulatov was reverted due to some regressions with NEON. This PR removes the additional type legalisation for vectors that are too wide, which should make this change less contentious.
1 parent 12937b1 commit ffa5bb1

File tree

5 files changed

+121
-25
lines changed

5 files changed

+121
-25
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17858,6 +17858,44 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1785817858
return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
1785917859
}
1786017860

17861+
// Turn vecreduce_add(zext/sext(...)) into SVE's [US]ADDV instruction.
17862+
static SDValue
17863+
performVecReduceAddExtCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
17864+
const AArch64TargetLowering &TLI) {
17865+
if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
17866+
N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND)
17867+
return SDValue();
17868+
17869+
SDValue VecOp = N->getOperand(0).getOperand(0);
17870+
EVT VecOpVT = VecOp.getValueType();
17871+
if (VecOpVT.getScalarType() == MVT::i1 || !TLI.isTypeLegal(VecOpVT) ||
17872+
(VecOpVT.isFixedLengthVector() &&
17873+
!TLI.useSVEForFixedLengthVectorVT(VecOpVT, /*OverrideNEON=*/true)))
17874+
return SDValue();
17875+
17876+
SDLoc DL(N);
17877+
SelectionDAG &DAG = DCI.DAG;
17878+
17879+
// The input type is legal so map VECREDUCE_ADD to UADDV/SADDV, e.g.
17880+
// i32 (vecreduce_add (zext nxv16i8 %op to nxv16i32))
17881+
// ->
17882+
// i32 (UADDV nxv16i8:%op)
17883+
EVT ElemType = N->getValueType(0);
17884+
SDValue Pg = getPredicateForVector(DAG, DL, VecOpVT);
17885+
if (VecOpVT.isFixedLengthVector()) {
17886+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VecOpVT);
17887+
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
17888+
}
17889+
bool IsSigned = N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND;
17890+
SDValue Res =
17891+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
17892+
DAG.getConstant(IsSigned ? Intrinsic::aarch64_sve_saddv
17893+
: Intrinsic::aarch64_sve_uaddv,
17894+
DL, MVT::i64),
17895+
Pg, VecOp);
17896+
return DAG.getAnyExtOrTrunc(Res, DL, ElemType);
17897+
}
17898+
1786117899
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
1786217900
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
1786317901
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
@@ -25546,8 +25584,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
2554625584
return performInsertVectorEltCombine(N, DCI);
2554725585
case ISD::EXTRACT_VECTOR_ELT:
2554825586
return performExtractVectorEltCombine(N, DCI, Subtarget);
25549-
case ISD::VECREDUCE_ADD:
25550-
return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
25587+
case ISD::VECREDUCE_ADD: {
25588+
if (SDValue Val = performVecReduceAddCombine(N, DCI.DAG, Subtarget))
25589+
return Val;
25590+
return performVecReduceAddExtCombine(N, DCI, *this);
25591+
}
2555125592
case AArch64ISD::UADDV:
2555225593
return performUADDVCombine(N, DAG);
2555325594
case AArch64ISD::SMULL:

llvm/test/CodeGen/AArch64/sve-doublereduct.ll

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,12 @@ define i32 @add_i32(<vscale x 8 x i32> %a, <vscale x 4 x i32> %b) {
103103
define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
104104
; CHECK-LABEL: add_ext_i16:
105105
; CHECK: // %bb.0:
106-
; CHECK-NEXT: uunpkhi z2.h, z0.b
107-
; CHECK-NEXT: uunpklo z0.h, z0.b
108-
; CHECK-NEXT: uunpkhi z3.h, z1.b
109-
; CHECK-NEXT: uunpklo z1.h, z1.b
110-
; CHECK-NEXT: ptrue p0.h
111-
; CHECK-NEXT: add z0.h, z0.h, z2.h
112-
; CHECK-NEXT: add z1.h, z1.h, z3.h
113-
; CHECK-NEXT: add z0.h, z0.h, z1.h
114-
; CHECK-NEXT: uaddv d0, p0, z0.h
115-
; CHECK-NEXT: fmov x0, d0
116-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
106+
; CHECK-NEXT: ptrue p0.b
107+
; CHECK-NEXT: uaddv d0, p0, z0.b
108+
; CHECK-NEXT: uaddv d1, p0, z1.b
109+
; CHECK-NEXT: fmov w8, s0
110+
; CHECK-NEXT: fmov w9, s1
111+
; CHECK-NEXT: add w0, w8, w9
117112
; CHECK-NEXT: ret
118113
%ae = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
119114
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
@@ -130,17 +125,16 @@ define i16 @add_ext_v32i16(<vscale x 32 x i8> %a, <vscale x 16 x i8> %b) {
130125
; CHECK-NEXT: uunpklo z4.h, z0.b
131126
; CHECK-NEXT: uunpkhi z1.h, z1.b
132127
; CHECK-NEXT: uunpkhi z0.h, z0.b
133-
; CHECK-NEXT: uunpkhi z5.h, z2.b
134-
; CHECK-NEXT: uunpklo z2.h, z2.b
135128
; CHECK-NEXT: ptrue p0.h
129+
; CHECK-NEXT: ptrue p1.b
136130
; CHECK-NEXT: add z0.h, z0.h, z1.h
137131
; CHECK-NEXT: add z1.h, z4.h, z3.h
138132
; CHECK-NEXT: add z0.h, z1.h, z0.h
139-
; CHECK-NEXT: add z1.h, z2.h, z5.h
140-
; CHECK-NEXT: add z0.h, z0.h, z1.h
133+
; CHECK-NEXT: uaddv d1, p1, z2.b
141134
; CHECK-NEXT: uaddv d0, p0, z0.h
142-
; CHECK-NEXT: fmov x0, d0
143-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
135+
; CHECK-NEXT: fmov w9, s1
136+
; CHECK-NEXT: fmov x8, d0
137+
; CHECK-NEXT: add w0, w8, w9
144138
; CHECK-NEXT: ret
145139
%ae = zext <vscale x 32 x i8> %a to <vscale x 32 x i16>
146140
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>

llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,34 @@ define i64 @uaddv_v32i64(ptr %a) vscale_range(16,0) #0 {
364364
ret i64 %res
365365
}
366366

367+
define i64 @uaddv_zext_v32i8_v32i64(ptr %a) vscale_range(2,0) #0 {
368+
; CHECK-LABEL: uaddv_zext_v32i8_v32i64:
369+
; CHECK: // %bb.0:
370+
; CHECK-NEXT: ptrue p0.b, vl32
371+
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
372+
; CHECK-NEXT: uaddv d0, p0, z0.b
373+
; CHECK-NEXT: fmov x0, d0
374+
; CHECK-NEXT: ret
375+
%op = load <32 x i8>, ptr %a
376+
%op.zext = zext <32 x i8> %op to <32 x i64>
377+
%res = call i64 @llvm.vector.reduce.add.v32i64(<32 x i64> %op.zext)
378+
ret i64 %res
379+
}
380+
381+
define i64 @uaddv_zext_v64i8_v64i64(ptr %a) vscale_range(4,0) #0 {
382+
; CHECK-LABEL: uaddv_zext_v64i8_v64i64:
383+
; CHECK: // %bb.0:
384+
; CHECK-NEXT: ptrue p0.b, vl64
385+
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
386+
; CHECK-NEXT: uaddv d0, p0, z0.b
387+
; CHECK-NEXT: fmov x0, d0
388+
; CHECK-NEXT: ret
389+
%op = load <64 x i8>, ptr %a
390+
%op.zext = zext <64 x i8> %op to <64 x i64>
391+
%res = call i64 @llvm.vector.reduce.add.v64i64(<64 x i64> %op.zext)
392+
ret i64 %res
393+
}
394+
367395
;
368396
; SMAXV
369397
;

llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=256 -verify-machineinstrs | FileCheck %s --check-prefixes=SVE256
32
; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=128 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
43
; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-n1 -O3 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
@@ -9,11 +8,7 @@ define internal i32 @test(ptr nocapture readonly %p1, i32 %i1, ptr nocapture rea
98
; SVE256: ld1b { z0.h }, p0/z,
109
; SVE256: ld1b { z1.h }, p0/z,
1110
; SVE256: sub z0.h, z0.h, z1.h
12-
; SVE256-NEXT: sunpklo z1.s, z0.h
13-
; SVE256-NEXT: ext z0.b, z0.b, z0.b, #16
14-
; SVE256-NEXT: sunpklo z0.s, z0.h
15-
; SVE256-NEXT: add z0.s, z1.s, z0.s
16-
; SVE256-NEXT: uaddv d0, p1, z0.s
11+
; SVE256-NEXT: saddv d0, p0, z0.h
1712

1813
; NEON-LABEL: test:
1914
; NEON: ldr q0, [x0, w9, sxtw]

llvm/test/CodeGen/AArch64/sve-int-reduce.ll

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,44 @@ define i64 @uaddv_nxv2i64(<vscale x 2 x i64> %a) {
188188
ret i64 %res
189189
}
190190

191+
define i16 @uaddv_nxv16i8_nxv16i16(<vscale x 16 x i8> %a) {
192+
; CHECK-LABEL: uaddv_nxv16i8_nxv16i16:
193+
; CHECK: // %bb.0:
194+
; CHECK-NEXT: ptrue p0.b
195+
; CHECK-NEXT: uaddv d0, p0, z0.b
196+
; CHECK-NEXT: fmov x0, d0
197+
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
198+
; CHECK-NEXT: ret
199+
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
200+
%2 = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> %1)
201+
ret i16 %2
202+
}
203+
204+
define i32 @uaddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
205+
; CHECK-LABEL: uaddv_nxv16i8_nxv16i32:
206+
; CHECK: // %bb.0:
207+
; CHECK-NEXT: ptrue p0.b
208+
; CHECK-NEXT: uaddv d0, p0, z0.b
209+
; CHECK-NEXT: fmov x0, d0
210+
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
211+
; CHECK-NEXT: ret
212+
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
213+
%2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
214+
ret i32 %2
215+
}
216+
217+
define i64 @uaddv_nxv16i8_nxv16i64(<vscale x 16 x i8> %a) {
218+
; CHECK-LABEL: uaddv_nxv16i8_nxv16i64:
219+
; CHECK: // %bb.0:
220+
; CHECK-NEXT: ptrue p0.b
221+
; CHECK-NEXT: uaddv d0, p0, z0.b
222+
; CHECK-NEXT: fmov x0, d0
223+
; CHECK-NEXT: ret
224+
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
225+
%2 = call i64 @llvm.vector.reduce.add.nxv16i64(<vscale x 16 x i64> %1)
226+
ret i64 %2
227+
}
228+
191229
; UMINV
192230

193231
define i8 @umin_nxv16i8(<vscale x 16 x i8> %a) {

0 commit comments

Comments
 (0)