Skip to content

Commit 8762ad8

Browse files
NickGuy-Armtomtor
authored andcommitted
[AArch64] Add AArch64 SVE lowering for usdot (llvm#143403)
1 parent a703b23 commit 8762ad8

File tree

4 files changed

+68
-115
lines changed

4 files changed

+68
-115
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,8 @@ def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
521521
SDTPartialReduceMLA>;
522522
def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
523523
SDTPartialReduceMLA>;
524+
def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
525+
SDTPartialReduceMLA>;
524526

525527
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
526528
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,6 +1895,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18951895

18961896
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
18971897

1898+
if (Subtarget->hasMatMulInt8()) {
1899+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv4i32,
1900+
MVT::nxv16i8, Legal);
1901+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::nxv2i64,
1902+
MVT::nxv16i8, Custom);
1903+
}
1904+
18981905
// Wide add types
18991906
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
19001907
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
@@ -7516,6 +7523,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
75167523
return LowerVECTOR_HISTOGRAM(Op, DAG);
75177524
case ISD::PARTIAL_REDUCE_SMLA:
75187525
case ISD::PARTIAL_REDUCE_UMLA:
7526+
case ISD::PARTIAL_REDUCE_SUMLA:
75197527
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
75207528
}
75217529
}

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4116,6 +4116,11 @@ let Predicates = [HasSVEAES2, HasNonStreamingSVE2_or_SSVE_AES] in {
41164116
def PMULL_2ZZZ_Q : sve_crypto_pmull_multi<"pmull">;
41174117
}
41184118

4119+
let Predicates = [HasSVE_or_SME, HasMatMulInt8] in {
4120+
def : Pat<(nxv4i32 (partial_reduce_sumla nxv4i32:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
4121+
(USDOT_ZZZ $Acc, $RHS, $LHS)>;
4122+
} // End HasSVE_or_SME, HasMatMulInt8
4123+
41194124
//===----------------------------------------------------------------------===//
41204125
// SME or SVE2.1 instructions
41214126
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 53 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
44
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
55
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
6-
; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
6+
; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
77

88
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
99
; CHECK-LABEL: udot:
@@ -106,23 +106,7 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
106106
;
107107
; CHECK-NEWLOWERING-LABEL: usdot:
108108
; CHECK-NEWLOWERING: // %bb.0: // %entry
109-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
110-
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
111-
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
112-
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
113-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
114-
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
115-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
116-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
117-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
118-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
119-
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z1.h
120-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z2.h
121-
; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
122-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
123-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
124-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
125-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
109+
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z1.b, z2.b
126110
; CHECK-NEWLOWERING-NEXT: ret
127111
entry:
128112
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -161,23 +145,7 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
161145
;
162146
; CHECK-NEWLOWERING-LABEL: sudot:
163147
; CHECK-NEWLOWERING: // %bb.0: // %entry
164-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
165-
; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
166-
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
167-
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
168-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
169-
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
170-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
171-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
172-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
173-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
174-
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z1.h
175-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z2.h
176-
; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
177-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
178-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
179-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
180-
; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
148+
; CHECK-NEWLOWERING-NEXT: usdot z0.s, z2.b, z1.b
181149
; CHECK-NEWLOWERING-NEXT: ret
182150
entry:
183151
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
@@ -329,46 +297,31 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
329297
; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z3.d
330298
; CHECK-NOI8MM-NEXT: ret
331299
;
332-
; CHECK-NEWLOWERING-LABEL: usdot_8to64:
333-
; CHECK-NEWLOWERING: // %bb.0: // %entry
334-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
335-
; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
336-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
337-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
338-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
339-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
340-
; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
341-
; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
342-
; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
343-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
344-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
345-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
346-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
347-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
348-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
349-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
350-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
351-
; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
352-
; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
353-
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
354-
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
355-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
356-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
357-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
358-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
359-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
360-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
361-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
362-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
363-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
364-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
365-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
366-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
367-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
368-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
369-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
370-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
371-
; CHECK-NEWLOWERING-NEXT: ret
300+
; CHECK-NEWLOWERING-SVE-LABEL: usdot_8to64:
301+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
302+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
303+
; CHECK-NEWLOWERING-SVE-NEXT: usdot z4.s, z2.b, z3.b
304+
; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z4.s
305+
; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z4.s
306+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
307+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z3.d
308+
; CHECK-NEWLOWERING-SVE-NEXT: ret
309+
;
310+
; CHECK-NEWLOWERING-SVE2-LABEL: usdot_8to64:
311+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
312+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
313+
; CHECK-NEWLOWERING-SVE2-NEXT: usdot z4.s, z2.b, z3.b
314+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
315+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
316+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
317+
;
318+
; CHECK-NEWLOWERING-SME-LABEL: usdot_8to64:
319+
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
320+
; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
321+
; CHECK-NEWLOWERING-SME-NEXT: usdot z4.s, z2.b, z3.b
322+
; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
323+
; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
324+
; CHECK-NEWLOWERING-SME-NEXT: ret
372325
entry:
373326
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
374327
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -430,46 +383,31 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
430383
; CHECK-NOI8MM-NEXT: mla z0.d, p0/m, z2.d, z3.d
431384
; CHECK-NOI8MM-NEXT: ret
432385
;
433-
; CHECK-NEWLOWERING-LABEL: sudot_8to64:
434-
; CHECK-NEWLOWERING: // %bb.0: // %entry
435-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
436-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
437-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
438-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
439-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
440-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
441-
; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
442-
; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
443-
; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
444-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
445-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
446-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
447-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
448-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
449-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
450-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
451-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
452-
; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
453-
; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
454-
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
455-
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
456-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
457-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
458-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
459-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
460-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
461-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
462-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
463-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
464-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
465-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
466-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
467-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
468-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
469-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
470-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
471-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
472-
; CHECK-NEWLOWERING-NEXT: ret
386+
; CHECK-NEWLOWERING-SVE-LABEL: sudot_8to64:
387+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
388+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
389+
; CHECK-NEWLOWERING-SVE-NEXT: usdot z4.s, z3.b, z2.b
390+
; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z4.s
391+
; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z4.s
392+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
393+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z3.d
394+
; CHECK-NEWLOWERING-SVE-NEXT: ret
395+
;
396+
; CHECK-NEWLOWERING-SVE2-LABEL: sudot_8to64:
397+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
398+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
399+
; CHECK-NEWLOWERING-SVE2-NEXT: usdot z4.s, z3.b, z2.b
400+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
401+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
402+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
403+
;
404+
; CHECK-NEWLOWERING-SME-LABEL: sudot_8to64:
405+
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
406+
; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
407+
; CHECK-NEWLOWERING-SME-NEXT: usdot z4.s, z3.b, z2.b
408+
; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
409+
; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
410+
; CHECK-NEWLOWERING-SME-NEXT: ret
473411
entry:
474412
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
475413
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>

0 commit comments

Comments
 (0)