Skip to content

Commit 421f1b7

Browse files
author
Cameron McInally
committed
[SVE] Lower fixed length VECREDUCE_FADD operation
Differential Revision: https://reviews.llvm.org/D89263
1 parent 94d9a4f commit 421f1b7

File tree

2 files changed

+270
-0
lines changed

2 files changed

+270
-0
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11251125
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
11261126
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
11271127
}
1128+
1129+
// Use SVE for vectors with more than 2 elements.
1130+
for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
1131+
setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
11281132
}
11291133
}
11301134

@@ -1261,6 +1265,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
12611265
setOperationAction(ISD::UMIN, VT, Custom);
12621266
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
12631267
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1268+
setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
12641269
setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
12651270
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
12661271
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -3963,6 +3968,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
39633968
case ISD::VECREDUCE_SMIN:
39643969
case ISD::VECREDUCE_UMAX:
39653970
case ISD::VECREDUCE_UMIN:
3971+
case ISD::VECREDUCE_FADD:
39663972
case ISD::VECREDUCE_FMAX:
39673973
case ISD::VECREDUCE_FMIN:
39683974
return LowerVECREDUCE(Op, DAG);
@@ -9749,6 +9755,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
97499755
bool OverrideNEON = Op.getOpcode() == ISD::VECREDUCE_AND ||
97509756
Op.getOpcode() == ISD::VECREDUCE_OR ||
97519757
Op.getOpcode() == ISD::VECREDUCE_XOR ||
9758+
Op.getOpcode() == ISD::VECREDUCE_FADD ||
97529759
(Op.getOpcode() != ISD::VECREDUCE_ADD &&
97539760
SrcVT.getVectorElementType() == MVT::i64);
97549761
if (useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
@@ -9769,6 +9776,8 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
97699776
return LowerFixedLengthReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
97709777
case ISD::VECREDUCE_XOR:
97719778
return LowerFixedLengthReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
9779+
case ISD::VECREDUCE_FADD:
9780+
return LowerFixedLengthReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
97729781
case ISD::VECREDUCE_FMAX:
97739782
return LowerFixedLengthReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
97749783
case ISD::VECREDUCE_FMIN:

llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,246 @@ target triple = "aarch64-unknown-linux-gnu"
2020
; Don't use SVE when its registers are no bigger than NEON.
2121
; NO_SVE-NOT: ptrue
2222

23+
;
24+
; FADDV
25+
;
26+
27+
; No single instruction NEON support for 4 element vectors.
28+
define half @faddv_v4f16(half %start, <4 x half> %a) #0 {
29+
; CHECK-LABEL: faddv_v4f16:
30+
; CHECK: ptrue [[PG:p[0-9]+]].h, vl4
31+
; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], z1.h
32+
; CHECK-NEXT: fadd h0, h0, [[RDX]]
33+
; CHECK-NEXT: ret
34+
%res = call fast half @llvm.vector.reduce.fadd.v4f16(half %start, <4 x half> %a)
35+
ret half %res
36+
}
37+
38+
; No single instruction NEON support for 8 element vectors.
39+
define half @faddv_v8f16(half %start, <8 x half> %a) #0 {
40+
; CHECK-LABEL: faddv_v8f16:
41+
; CHECK: ptrue [[PG:p[0-9]+]].h, vl8
42+
; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], z1.h
43+
; CHECK-NEXT: fadd h0, h0, [[RDX]]
44+
; CHECK-NEXT: ret
45+
%res = call fast half @llvm.vector.reduce.fadd.v8f16(half %start, <8 x half> %a)
46+
ret half %res
47+
}
48+
49+
define half @faddv_v16f16(half %start, <16 x half>* %a) #0 {
50+
; CHECK-LABEL: faddv_v16f16:
51+
; CHECK: ptrue [[PG:p[0-9]+]].h, vl16
52+
; CHECK-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0]
53+
; CHECK-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h
54+
; CHECK-NEXT: fadd h0, h0, [[RDX]]
55+
; CHECK-NEXT: ret
56+
%op = load <16 x half>, <16 x half>* %a
57+
%res = call fast half @llvm.vector.reduce.fadd.v16f16(half %start, <16 x half> %op)
58+
ret half %res
59+
}
60+
61+
define half @faddv_v32f16(half %start, <32 x half>* %a) #0 {
62+
; CHECK-LABEL: faddv_v32f16:
63+
; VBITS_GE_512: ptrue [[PG:p[0-9]+]].h, vl32
64+
; VBITS_GE_512-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0]
65+
; VBITS_GE_512-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h
66+
; VBITS_GE_512-NEXT: fadd h0, h0, [[RDX]]
67+
; VBITS_GE_512-NEXT: ret
68+
69+
; Ensure sensible type legalisation.
70+
; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].h, vl16
71+
; VBITS_EQ_256-DAG: add x[[A_HI:[0-9]+]], x0, #32
72+
; VBITS_EQ_256-DAG: ld1h { [[LO:z[0-9]+]].h }, [[PG]]/z, [x0]
73+
; VBITS_EQ_256-DAG: ld1h { [[HI:z[0-9]+]].h }, [[PG]]/z, [x[[A_HI]]]
74+
; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].h, [[PG]]/m, [[LO]].h, [[HI]].h
75+
; VBITS_EQ_256-DAG: faddv [[RDX:h[0-9]+]], [[PG]], [[ADD]].h
76+
; VBITS_EQ_256-DAG: fadd h0, h0, [[RDX]]
77+
; VBITS_EQ_256-NEXT: ret
78+
%op = load <32 x half>, <32 x half>* %a
79+
%res = call fast half @llvm.vector.reduce.fadd.v32f16(half %start, <32 x half> %op)
80+
ret half %res
81+
}
82+
83+
define half @faddv_v64f16(half %start, <64 x half>* %a) #0 {
84+
; CHECK-LABEL: faddv_v64f16:
85+
; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].h, vl64
86+
; VBITS_GE_1024-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0]
87+
; VBITS_GE_1024-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h
88+
; VBITS_GE_1024-NEXT: fadd h0, h0, [[RDX]]
89+
; VBITS_GE_1024-NEXT: ret
90+
%op = load <64 x half>, <64 x half>* %a
91+
%res = call fast half @llvm.vector.reduce.fadd.v64f16(half %start, <64 x half> %op)
92+
ret half %res
93+
}
94+
95+
define half @faddv_v128f16(half %start, <128 x half>* %a) #0 {
96+
; CHECK-LABEL: faddv_v128f16:
97+
; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].h, vl128
98+
; VBITS_GE_2048-NEXT: ld1h { [[OP:z[0-9]+]].h }, [[PG]]/z, [x0]
99+
; VBITS_GE_2048-NEXT: faddv [[RDX:h[0-9]+]], [[PG]], [[OP]].h
100+
; VBITS_GE_2048-NEXT: fadd h0, h0, [[RDX]]
101+
; VBITS_GE_2048-NEXT: ret
102+
%op = load <128 x half>, <128 x half>* %a
103+
%res = call fast half @llvm.vector.reduce.fadd.v128f16(half %start, <128 x half> %op)
104+
ret half %res
105+
}
106+
107+
; Don't use SVE for 2 element vectors.
108+
define float @faddv_v2f32(float %start, <2 x float> %a) #0 {
109+
; CHECK-LABEL: faddv_v2f32:
110+
; CHECK: faddp s1, v1.2s
111+
; CHECK-NEXT: fadd s0, s0, s1
112+
; CHECK-NEXT: ret
113+
%res = call fast float @llvm.vector.reduce.fadd.v2f32(float %start, <2 x float> %a)
114+
ret float %res
115+
}
116+
117+
; No single instruction NEON support for 4 element vectors.
118+
define float @faddv_v4f32(float %start, <4 x float> %a) #0 {
119+
; CHECK-LABEL: faddv_v4f32:
120+
; CHECK: ptrue [[PG:p[0-9]+]].s, vl4
121+
; CHECK-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], z1.s
122+
; CHECK-NEXT: fadd s0, s0, [[RDX]]
123+
; CHECK-NEXT: ret
124+
%res = call fast float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %a)
125+
ret float %res
126+
}
127+
128+
define float @faddv_v8f32(float %start, <8 x float>* %a) #0 {
129+
; CHECK-LABEL: faddv_v8f32:
130+
; CHECK: ptrue [[PG:p[0-9]+]].s, vl8
131+
; CHECK-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0]
132+
; CHECK-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s
133+
; CHECK-NEXT: fadd s0, s0, [[RDX]]
134+
; CHECK-NEXT: ret
135+
%op = load <8 x float>, <8 x float>* %a
136+
%res = call fast float @llvm.vector.reduce.fadd.v8f32(float %start, <8 x float> %op)
137+
ret float %res
138+
}
139+
140+
define float @faddv_v16f32(float %start, <16 x float>* %a) #0 {
141+
; CHECK-LABEL: faddv_v16f32:
142+
; VBITS_GE_512: ptrue [[PG:p[0-9]+]].s, vl16
143+
; VBITS_GE_512-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0]
144+
; VBITS_GE_512-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s
145+
; VBITS_GE_512-NEXT: fadd s0, s0, [[RDX]]
146+
; VBITS_GE_512-NEXT: ret
147+
148+
; Ensure sensible type legalisation.
149+
; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].s, vl8
150+
; VBITS_EQ_256-DAG: add x[[A_HI:[0-9]+]], x0, #32
151+
; VBITS_EQ_256-DAG: ld1w { [[LO:z[0-9]+]].s }, [[PG]]/z, [x0]
152+
; VBITS_EQ_256-DAG: ld1w { [[HI:z[0-9]+]].s }, [[PG]]/z, [x[[A_HI]]]
153+
; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].s, [[PG]]/m, [[LO]].s, [[HI]].s
154+
; VBITS_EQ_256-DAG: faddv [[RDX:s[0-9]+]], [[PG]], [[ADD]].s
155+
; VBITS_EQ_256-DAG: fadd s0, s0, [[RDX]]
156+
; VBITS_EQ_256-NEXT: ret
157+
%op = load <16 x float>, <16 x float>* %a
158+
%res = call fast float @llvm.vector.reduce.fadd.v16f32(float %start, <16 x float> %op)
159+
ret float %res
160+
}
161+
162+
define float @faddv_v32f32(float %start, <32 x float>* %a) #0 {
163+
; CHECK-LABEL: faddv_v32f32:
164+
; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].s, vl32
165+
; VBITS_GE_1024-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0]
166+
; VBITS_GE_1024-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s
167+
; VBITS_GE_1024-NEXT: fadd s0, s0, [[RDX]]
168+
; VBITS_GE_1024-NEXT: ret
169+
%op = load <32 x float>, <32 x float>* %a
170+
%res = call fast float @llvm.vector.reduce.fadd.v32f32(float %start, <32 x float> %op)
171+
ret float %res
172+
}
173+
174+
define float @faddv_v64f32(float %start, <64 x float>* %a) #0 {
175+
; CHECK-LABEL: faddv_v64f32:
176+
; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl64
177+
; VBITS_GE_2048-NEXT: ld1w { [[OP:z[0-9]+]].s }, [[PG]]/z, [x0]
178+
; VBITS_GE_2048-NEXT: faddv [[RDX:s[0-9]+]], [[PG]], [[OP]].s
179+
; VBITS_GE_2048-NEXT: fadd s0, s0, [[RDX]]
180+
; VBITS_GE_2048-NEXT: ret
181+
%op = load <64 x float>, <64 x float>* %a
182+
%res = call fast float @llvm.vector.reduce.fadd.v64f32(float %start, <64 x float> %op)
183+
ret float %res
184+
}
185+
186+
; Don't use SVE for 1 element vectors.
187+
define double @faddv_v1f64(double %start, <1 x double> %a) #0 {
188+
; CHECK-LABEL: faddv_v1f64:
189+
; CHECK: fadd d0, d0, d1
190+
; CHECK-NEXT: ret
191+
%res = call fast double @llvm.vector.reduce.fadd.v1f64(double %start, <1 x double> %a)
192+
ret double %res
193+
}
194+
195+
; Don't use SVE for 2 element vectors.
196+
define double @faddv_v2f64(double %start, <2 x double> %a) #0 {
197+
; CHECK-LABEL: faddv_v2f64:
198+
; CHECK: faddp d1, v1.2d
199+
; CHECK-NEXT: fadd d0, d0, d1
200+
; CHECK-NEXT: ret
201+
%res = call fast double @llvm.vector.reduce.fadd.v2f64(double %start, <2 x double> %a)
202+
ret double %res
203+
}
204+
205+
define double @faddv_v4f64(double %start, <4 x double>* %a) #0 {
206+
; CHECK-LABEL: faddv_v4f64:
207+
; CHECK: ptrue [[PG:p[0-9]+]].d, vl4
208+
; CHECK-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0]
209+
; CHECK-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d
210+
; CHECK-NEXT: fadd d0, d0, [[RDX]]
211+
; CHECK-NEXT: ret
212+
%op = load <4 x double>, <4 x double>* %a
213+
%res = call fast double @llvm.vector.reduce.fadd.v4f64(double %start, <4 x double> %op)
214+
ret double %res
215+
}
216+
217+
define double @faddv_v8f64(double %start, <8 x double>* %a) #0 {
218+
; CHECK-LABEL: faddv_v8f64:
219+
; VBITS_GE_512: ptrue [[PG:p[0-9]+]].d, vl8
220+
; VBITS_GE_512-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0]
221+
; VBITS_GE_512-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d
222+
; VBITS_GE_512-NEXT: fadd d0, d0, [[RDX]]
223+
; VBITS_GE_512-NEXT: ret
224+
225+
; Ensure sensible type legalisation.
226+
; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].d, vl4
227+
; VBITS_EQ_256-DAG: add x[[A_HI:[0-9]+]], x0, #32
228+
; VBITS_EQ_256-DAG: ld1d { [[LO:z[0-9]+]].d }, [[PG]]/z, [x0]
229+
; VBITS_EQ_256-DAG: ld1d { [[HI:z[0-9]+]].d }, [[PG]]/z, [x[[A_HI]]]
230+
; VBITS_EQ_256-DAG: fadd [[ADD:z[0-9]+]].d, [[PG]]/m, [[LO]].d, [[HI]].d
231+
; VBITS_EQ_256-DAG: faddv [[RDX:d[0-9]+]], [[PG]], [[ADD]].d
232+
; VBITS_EQ_256-DAG: fadd d0, d0, [[RDX]]
233+
; VBITS_EQ_256-NEXT: ret
234+
%op = load <8 x double>, <8 x double>* %a
235+
%res = call fast double @llvm.vector.reduce.fadd.v8f64(double %start, <8 x double> %op)
236+
ret double %res
237+
}
238+
239+
define double @faddv_v16f64(double %start, <16 x double>* %a) #0 {
240+
; CHECK-LABEL: faddv_v16f64:
241+
; VBITS_GE_1024: ptrue [[PG:p[0-9]+]].d, vl16
242+
; VBITS_GE_1024-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0]
243+
; VBITS_GE_1024-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d
244+
; VBITS_GE_1024-NEXT: fadd d0, d0, [[RDX]]
245+
; VBITS_GE_1024-NEXT: ret
246+
%op = load <16 x double>, <16 x double>* %a
247+
%res = call fast double @llvm.vector.reduce.fadd.v16f64(double %start, <16 x double> %op)
248+
ret double %res
249+
}
250+
251+
define double @faddv_v32f64(double %start, <32 x double>* %a) #0 {
252+
; CHECK-LABEL: faddv_v32f64:
253+
; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].d, vl32
254+
; VBITS_GE_2048-NEXT: ld1d { [[OP:z[0-9]+]].d }, [[PG]]/z, [x0]
255+
; VBITS_GE_2048-NEXT: faddv [[RDX:d[0-9]+]], [[PG]], [[OP]].d
256+
; VBITS_GE_2048-NEXT: fadd d0, d0, [[RDX]]
257+
; VBITS_GE_2048-NEXT: ret
258+
%op = load <32 x double>, <32 x double>* %a
259+
%res = call fast double @llvm.vector.reduce.fadd.v32f64(double %start, <32 x double> %op)
260+
ret double %res
261+
}
262+
23263
;
24264
; FMAXV
25265
;
@@ -456,6 +696,27 @@ define double @fminv_v32f64(<32 x double>* %a) #0 {
456696

457697
attributes #0 = { "target-features"="+sve" }
458698

699+
declare half @llvm.vector.reduce.fadd.v4f16(half, <4 x half>)
700+
declare half @llvm.vector.reduce.fadd.v8f16(half, <8 x half>)
701+
declare half @llvm.vector.reduce.fadd.v16f16(half, <16 x half>)
702+
declare half @llvm.vector.reduce.fadd.v32f16(half, <32 x half>)
703+
declare half @llvm.vector.reduce.fadd.v64f16(half, <64 x half>)
704+
declare half @llvm.vector.reduce.fadd.v128f16(half, <128 x half>)
705+
706+
declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>)
707+
declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)
708+
declare float @llvm.vector.reduce.fadd.v8f32(float, <8 x float>)
709+
declare float @llvm.vector.reduce.fadd.v16f32(float, <16 x float>)
710+
declare float @llvm.vector.reduce.fadd.v32f32(float, <32 x float>)
711+
declare float @llvm.vector.reduce.fadd.v64f32(float, <64 x float>)
712+
713+
declare double @llvm.vector.reduce.fadd.v1f64(double, <1 x double>)
714+
declare double @llvm.vector.reduce.fadd.v2f64(double, <2 x double>)
715+
declare double @llvm.vector.reduce.fadd.v4f64(double, <4 x double>)
716+
declare double @llvm.vector.reduce.fadd.v8f64(double, <8 x double>)
717+
declare double @llvm.vector.reduce.fadd.v16f64(double, <16 x double>)
718+
declare double @llvm.vector.reduce.fadd.v32f64(double, <32 x double>)
719+
459720
declare half @llvm.vector.reduce.fmax.v4f16(<4 x half>)
460721
declare half @llvm.vector.reduce.fmax.v8f16(<8 x half>)
461722
declare half @llvm.vector.reduce.fmax.v16f16(<16 x half>)

0 commit comments

Comments
 (0)