Skip to content

Commit 1a80828

Browse files
committed
[AArch64] Extend vecreduce -> udot handling to mla reductions
We previously had lowering for: vecreduce.add(zext(X)) to vecreduce.add(UDOT(zero, X, one)). This extends that to also handle: vecreduce.add(mul(zext(X), zext(Y))) to vecreduce.add(UDOT(zero, X, Y)). It extends the existing code to optionally handle a mul with equal extends. Differential Revision: https://reviews.llvm.org/D97280
1 parent d75c9e6 commit 1a80828

File tree

2 files changed

+56
-103
lines changed

2 files changed

+56
-103
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11747,31 +11747,46 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
1174711747

1174811748
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
1174911749
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
11750+
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
1175011751
static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1175111752
const AArch64Subtarget *ST) {
1175211753
SDValue Op0 = N->getOperand(0);
11753-
if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
11754-
return SDValue();
11755-
11756-
if (Op0.getValueType().getVectorElementType() != MVT::i32)
11754+
if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 ||
11755+
Op0.getValueType().getVectorElementType() != MVT::i32)
1175711756
return SDValue();
1175811757

1175911758
unsigned ExtOpcode = Op0.getOpcode();
11759+
SDValue A = Op0;
11760+
SDValue B;
11761+
if (ExtOpcode == ISD::MUL) {
11762+
A = Op0.getOperand(0);
11763+
B = Op0.getOperand(1);
11764+
if (A.getOpcode() != B.getOpcode() ||
11765+
A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
11766+
return SDValue();
11767+
ExtOpcode = A.getOpcode();
11768+
}
1176011769
if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
1176111770
return SDValue();
1176211771

11763-
EVT Op0VT = Op0.getOperand(0).getValueType();
11772+
EVT Op0VT = A.getOperand(0).getValueType();
1176411773
if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
1176511774
return SDValue();
1176611775

1176711776
SDLoc DL(Op0);
11768-
SDValue Ones = DAG.getConstant(1, DL, Op0VT);
11777+
// For non-mla reductions B can be set to 1. For MLA we take the operand of
11778+
// the extend B.
11779+
if (!B)
11780+
B = DAG.getConstant(1, DL, Op0VT);
11781+
else
11782+
B = B.getOperand(0);
11783+
1176911784
SDValue Zeros =
1177011785
DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
1177111786
auto DotOpcode =
1177211787
(ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
1177311788
SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
11774-
Ones, Op0.getOperand(0));
11789+
A.getOperand(0), B);
1177511790
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
1177611791
}
1177711792

llvm/test/CodeGen/AArch64/neon-dotreduce.ll

Lines changed: 34 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@ define i32 @test_udot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b)
99
; CHECK: // %bb.0: // %entry
1010
; CHECK-NEXT: ldr d0, [x0]
1111
; CHECK-NEXT: ldr d1, [x1]
12-
; CHECK-NEXT: dup v2.2s, wzr
12+
; CHECK-NEXT: movi v2.2d, #0000000000000000
1313
; CHECK-NEXT: udot v2.2s, v1.8b, v0.8b
1414
; CHECK-NEXT: addp v0.2s, v2.2s, v2.2s
15-
; CHECK-NEXT: fmov x0, d0
16-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
15+
; CHECK-NEXT: fmov w0, s0
1716
; CHECK-NEXT: ret
1817
entry:
1918
%0 = bitcast i8* %a to <8 x i8>*
@@ -33,7 +32,7 @@ define i32 @test_udot_v8i8_nomla(i8* nocapture readonly %a1) {
3332
; CHECK-NEXT: ldr d0, [x0]
3433
; CHECK-NEXT: movi v1.2d, #0000000000000000
3534
; CHECK-NEXT: movi v2.8b, #1
36-
; CHECK-NEXT: udot v1.2s, v2.8b, v0.8b
35+
; CHECK-NEXT: udot v1.2s, v0.8b, v2.8b
3736
; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s
3837
; CHECK-NEXT: fmov w0, s0
3938
; CHECK-NEXT: ret
@@ -50,11 +49,10 @@ define i32 @test_sdot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b)
5049
; CHECK: // %bb.0: // %entry
5150
; CHECK-NEXT: ldr d0, [x0]
5251
; CHECK-NEXT: ldr d1, [x1]
53-
; CHECK-NEXT: dup v2.2s, wzr
52+
; CHECK-NEXT: movi v2.2d, #0000000000000000
5453
; CHECK-NEXT: sdot v2.2s, v1.8b, v0.8b
5554
; CHECK-NEXT: addp v0.2s, v2.2s, v2.2s
56-
; CHECK-NEXT: fmov x0, d0
57-
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
55+
; CHECK-NEXT: fmov w0, s0
5856
; CHECK-NEXT: ret
5957
entry:
6058
%0 = bitcast i8* %a to <8 x i8>*
@@ -74,7 +72,7 @@ define i32 @test_sdot_v8i8_nomla(i8* nocapture readonly %a1) {
7472
; CHECK-NEXT: ldr d0, [x0]
7573
; CHECK-NEXT: movi v1.2d, #0000000000000000
7674
; CHECK-NEXT: movi v2.8b, #1
77-
; CHECK-NEXT: sdot v1.2s, v2.8b, v0.8b
75+
; CHECK-NEXT: sdot v1.2s, v0.8b, v2.8b
7876
; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s
7977
; CHECK-NEXT: fmov w0, s0
8078
; CHECK-NEXT: ret
@@ -92,7 +90,7 @@ define i32 @test_udot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b
9290
; CHECK: // %bb.0: // %entry
9391
; CHECK-NEXT: ldr q0, [x0]
9492
; CHECK-NEXT: ldr q1, [x1]
95-
; CHECK-NEXT: dup v2.4s, wzr
93+
; CHECK-NEXT: movi v2.2d, #0000000000000000
9694
; CHECK-NEXT: udot v2.4s, v1.16b, v0.16b
9795
; CHECK-NEXT: addv s0, v2.4s
9896
; CHECK-NEXT: fmov w8, s0
@@ -117,7 +115,7 @@ define i32 @test_udot_v16i8_nomla(i8* nocapture readonly %a1) {
117115
; CHECK-NEXT: ldr q0, [x0]
118116
; CHECK-NEXT: movi v1.16b, #1
119117
; CHECK-NEXT: movi v2.2d, #0000000000000000
120-
; CHECK-NEXT: udot v2.4s, v1.16b, v0.16b
118+
; CHECK-NEXT: udot v2.4s, v0.16b, v1.16b
121119
; CHECK-NEXT: addv s0, v2.4s
122120
; CHECK-NEXT: fmov w0, s0
123121
; CHECK-NEXT: ret
@@ -134,7 +132,7 @@ define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b
134132
; CHECK: // %bb.0: // %entry
135133
; CHECK-NEXT: ldr q0, [x0]
136134
; CHECK-NEXT: ldr q1, [x1]
137-
; CHECK-NEXT: dup v2.4s, wzr
135+
; CHECK-NEXT: movi v2.2d, #0000000000000000
138136
; CHECK-NEXT: sdot v2.4s, v1.16b, v0.16b
139137
; CHECK-NEXT: addv s0, v2.4s
140138
; CHECK-NEXT: fmov w8, s0
@@ -159,7 +157,7 @@ define i32 @test_sdot_v16i8_nomla(i8* nocapture readonly %a1) {
159157
; CHECK-NEXT: ldr q0, [x0]
160158
; CHECK-NEXT: movi v1.16b, #1
161159
; CHECK-NEXT: movi v2.2d, #0000000000000000
162-
; CHECK-NEXT: sdot v2.4s, v1.16b, v0.16b
160+
; CHECK-NEXT: sdot v2.4s, v0.16b, v1.16b
163161
; CHECK-NEXT: addv s0, v2.4s
164162
; CHECK-NEXT: fmov w0, s0
165163
; CHECK-NEXT: ret
@@ -175,20 +173,10 @@ entry:
175173
define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
176174
; CHECK-LABEL: test_udot_v8i8_double:
177175
; CHECK: // %bb.0: // %entry
178-
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
179-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
180-
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
181-
; CHECK-NEXT: ushll v3.8h, v3.8b, #0
182-
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
183-
; CHECK-NEXT: ext v5.16b, v1.16b, v1.16b, #8
184-
; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h
185-
; CHECK-NEXT: ext v1.16b, v2.16b, v2.16b, #8
186-
; CHECK-NEXT: umull v2.4s, v2.4h, v3.4h
187-
; CHECK-NEXT: ext v3.16b, v3.16b, v3.16b, #8
188-
; CHECK-NEXT: umlal v0.4s, v4.4h, v5.4h
189-
; CHECK-NEXT: umlal v2.4s, v1.4h, v3.4h
190-
; CHECK-NEXT: add v0.4s, v0.4s, v2.4s
191-
; CHECK-NEXT: addv s0, v0.4s
176+
; CHECK-NEXT: movi v4.2d, #0000000000000000
177+
; CHECK-NEXT: udot v4.2s, v2.8b, v3.8b
178+
; CHECK-NEXT: udot v4.2s, v0.8b, v1.8b
179+
; CHECK-NEXT: addp v0.2s, v4.2s, v4.2s
192180
; CHECK-NEXT: fmov w0, s0
193181
; CHECK-NEXT: ret
194182
entry:
@@ -209,8 +197,8 @@ define i32 @test_udot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <
209197
; CHECK: // %bb.0: // %entry
210198
; CHECK-NEXT: movi v1.2d, #0000000000000000
211199
; CHECK-NEXT: movi v3.8b, #1
212-
; CHECK-NEXT: udot v1.2s, v3.8b, v2.8b
213-
; CHECK-NEXT: udot v1.2s, v3.8b, v0.8b
200+
; CHECK-NEXT: udot v1.2s, v2.8b, v3.8b
201+
; CHECK-NEXT: udot v1.2s, v0.8b, v3.8b
214202
; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s
215203
; CHECK-NEXT: fmov w0, s0
216204
; CHECK-NEXT: ret
@@ -226,30 +214,10 @@ entry:
226214
define i32 @test_udot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
227215
; CHECK-LABEL: test_udot_v16i8_double:
228216
; CHECK: // %bb.0: // %entry
229-
; CHECK-NEXT: ushll2 v4.8h, v0.16b, #0
230-
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
231-
; CHECK-NEXT: ushll2 v5.8h, v1.16b, #0
232-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
233-
; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
234-
; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8
235-
; CHECK-NEXT: umull2 v16.4s, v0.8h, v1.8h
236-
; CHECK-NEXT: umlal v16.4s, v6.4h, v7.4h
237-
; CHECK-NEXT: ushll2 v6.8h, v2.16b, #0
238-
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
239-
; CHECK-NEXT: ushll2 v7.8h, v3.16b, #0
240-
; CHECK-NEXT: ushll v3.8h, v3.8b, #0
241-
; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h
242-
; CHECK-NEXT: ext v1.16b, v6.16b, v6.16b, #8
243-
; CHECK-NEXT: umlal v0.4s, v4.4h, v5.4h
244-
; CHECK-NEXT: ext v4.16b, v7.16b, v7.16b, #8
245-
; CHECK-NEXT: umull v5.4s, v2.4h, v3.4h
246-
; CHECK-NEXT: umull2 v2.4s, v2.8h, v3.8h
247-
; CHECK-NEXT: umlal v2.4s, v1.4h, v4.4h
248-
; CHECK-NEXT: umlal v5.4s, v6.4h, v7.4h
249-
; CHECK-NEXT: add v0.4s, v0.4s, v16.4s
250-
; CHECK-NEXT: add v1.4s, v5.4s, v2.4s
251-
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
252-
; CHECK-NEXT: addv s0, v0.4s
217+
; CHECK-NEXT: movi v4.2d, #0000000000000000
218+
; CHECK-NEXT: udot v4.4s, v2.16b, v3.16b
219+
; CHECK-NEXT: udot v4.4s, v0.16b, v1.16b
220+
; CHECK-NEXT: addv s0, v4.4s
253221
; CHECK-NEXT: fmov w0, s0
254222
; CHECK-NEXT: ret
255223
entry:
@@ -270,8 +238,8 @@ define i32 @test_udot_v16i8_double_nomla(<16 x i8> %a, <16 x i8> %b, <16 x i8> %
270238
; CHECK: // %bb.0: // %entry
271239
; CHECK-NEXT: movi v1.16b, #1
272240
; CHECK-NEXT: movi v3.2d, #0000000000000000
273-
; CHECK-NEXT: udot v3.4s, v1.16b, v2.16b
274-
; CHECK-NEXT: udot v3.4s, v1.16b, v0.16b
241+
; CHECK-NEXT: udot v3.4s, v2.16b, v1.16b
242+
; CHECK-NEXT: udot v3.4s, v0.16b, v1.16b
275243
; CHECK-NEXT: addv s0, v3.4s
276244
; CHECK-NEXT: fmov w0, s0
277245
; CHECK-NEXT: ret
@@ -287,20 +255,10 @@ entry:
287255
define i32 @test_sdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
288256
; CHECK-LABEL: test_sdot_v8i8_double:
289257
; CHECK: // %bb.0: // %entry
290-
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
291-
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
292-
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
293-
; CHECK-NEXT: sshll v3.8h, v3.8b, #0
294-
; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
295-
; CHECK-NEXT: ext v5.16b, v1.16b, v1.16b, #8
296-
; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h
297-
; CHECK-NEXT: ext v1.16b, v2.16b, v2.16b, #8
298-
; CHECK-NEXT: smull v2.4s, v2.4h, v3.4h
299-
; CHECK-NEXT: ext v3.16b, v3.16b, v3.16b, #8
300-
; CHECK-NEXT: smlal v0.4s, v4.4h, v5.4h
301-
; CHECK-NEXT: smlal v2.4s, v1.4h, v3.4h
302-
; CHECK-NEXT: add v0.4s, v0.4s, v2.4s
303-
; CHECK-NEXT: addv s0, v0.4s
258+
; CHECK-NEXT: movi v4.2d, #0000000000000000
259+
; CHECK-NEXT: sdot v4.2s, v2.8b, v3.8b
260+
; CHECK-NEXT: sdot v4.2s, v0.8b, v1.8b
261+
; CHECK-NEXT: addp v0.2s, v4.2s, v4.2s
304262
; CHECK-NEXT: fmov w0, s0
305263
; CHECK-NEXT: ret
306264
entry:
@@ -321,8 +279,8 @@ define i32 @test_sdot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <
321279
; CHECK: // %bb.0: // %entry
322280
; CHECK-NEXT: movi v1.2d, #0000000000000000
323281
; CHECK-NEXT: movi v3.8b, #1
324-
; CHECK-NEXT: sdot v1.2s, v3.8b, v2.8b
325-
; CHECK-NEXT: sdot v1.2s, v3.8b, v0.8b
282+
; CHECK-NEXT: sdot v1.2s, v2.8b, v3.8b
283+
; CHECK-NEXT: sdot v1.2s, v0.8b, v3.8b
326284
; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s
327285
; CHECK-NEXT: fmov w0, s0
328286
; CHECK-NEXT: ret
@@ -338,30 +296,10 @@ entry:
338296
define i32 @test_sdot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
339297
; CHECK-LABEL: test_sdot_v16i8_double:
340298
; CHECK: // %bb.0: // %entry
341-
; CHECK-NEXT: sshll2 v4.8h, v0.16b, #0
342-
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
343-
; CHECK-NEXT: sshll2 v5.8h, v1.16b, #0
344-
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
345-
; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
346-
; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8
347-
; CHECK-NEXT: smull2 v16.4s, v0.8h, v1.8h
348-
; CHECK-NEXT: smlal v16.4s, v6.4h, v7.4h
349-
; CHECK-NEXT: sshll2 v6.8h, v2.16b, #0
350-
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
351-
; CHECK-NEXT: sshll2 v7.8h, v3.16b, #0
352-
; CHECK-NEXT: sshll v3.8h, v3.8b, #0
353-
; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h
354-
; CHECK-NEXT: ext v1.16b, v6.16b, v6.16b, #8
355-
; CHECK-NEXT: smlal v0.4s, v4.4h, v5.4h
356-
; CHECK-NEXT: ext v4.16b, v7.16b, v7.16b, #8
357-
; CHECK-NEXT: smull v5.4s, v2.4h, v3.4h
358-
; CHECK-NEXT: smull2 v2.4s, v2.8h, v3.8h
359-
; CHECK-NEXT: smlal v2.4s, v1.4h, v4.4h
360-
; CHECK-NEXT: smlal v5.4s, v6.4h, v7.4h
361-
; CHECK-NEXT: add v0.4s, v0.4s, v16.4s
362-
; CHECK-NEXT: add v1.4s, v5.4s, v2.4s
363-
; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
364-
; CHECK-NEXT: addv s0, v0.4s
299+
; CHECK-NEXT: movi v4.2d, #0000000000000000
300+
; CHECK-NEXT: sdot v4.4s, v2.16b, v3.16b
301+
; CHECK-NEXT: sdot v4.4s, v0.16b, v1.16b
302+
; CHECK-NEXT: addv s0, v4.4s
365303
; CHECK-NEXT: fmov w0, s0
366304
; CHECK-NEXT: ret
367305
entry:
@@ -382,8 +320,8 @@ define i32 @test_sdot_v16i8_double_nomla(<16 x i8> %a, <16 x i8> %b, <16 x i8> %
382320
; CHECK: // %bb.0: // %entry
383321
; CHECK-NEXT: movi v1.16b, #1
384322
; CHECK-NEXT: movi v3.2d, #0000000000000000
385-
; CHECK-NEXT: sdot v3.4s, v1.16b, v2.16b
386-
; CHECK-NEXT: sdot v3.4s, v1.16b, v0.16b
323+
; CHECK-NEXT: sdot v3.4s, v2.16b, v1.16b
324+
; CHECK-NEXT: sdot v3.4s, v0.16b, v1.16b
387325
; CHECK-NEXT: addv s0, v3.4s
388326
; CHECK-NEXT: fmov w0, s0
389327
; CHECK-NEXT: ret

0 commit comments

Comments
 (0)