Skip to content

Commit 7c946e6

Browse files
authored
[AArch64] Add Neon USDOT support (#143525)
1 parent 9630d7c commit 7c946e6

File tree

3 files changed

+31
-112
lines changed

3 files changed

+31
-112
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1464,6 +1464,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14641464
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
14651465
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
14661466
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1467+
1468+
if (Subtarget->hasMatMulInt8()) {
1469+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
1470+
MVT::v16i8, Legal);
1471+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
1472+
MVT::v16i8, Custom);
1473+
1474+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
1475+
MVT::v8i8, Legal);
1476+
}
14671477
}
14681478

14691479
} else /* !isNeonAvailable */ {

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1711,6 +1711,11 @@ multiclass SIMDSUDOTIndex {
17111711

17121712
defm SUDOTlane : SIMDSUDOTIndex;
17131713

1714+
def : Pat<(v2i32 (partial_reduce_sumla v2i32:$Acc, v8i8:$LHS, v8i8:$RHS)),
1715+
(USDOTv8i8 $Acc, $RHS, $LHS)>;
1716+
def : Pat<(v4i32 (partial_reduce_sumla v4i32:$Acc, v16i8:$LHS, v16i8:$RHS)),
1717+
(USDOTv16i8 $Acc, $RHS, $LHS)>;
1718+
17141719
}
17151720

17161721
// ARMv8.2-A FP16 Fused Multiply-Add Long

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 16 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -181,14 +181,7 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
181181
;
182182
; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
183183
; CHECK-NEWLOWERING-I8MM: // %bb.0:
184-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v3.8h, v1.8b, #0
185-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
186-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v1.8h, v1.16b, #0
187-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
188-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
189-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
190-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
191-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
184+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
192185
; CHECK-NEWLOWERING-I8MM-NEXT: ret
193186
%u.wide = zext <16 x i8> %u to <16 x i32>
194187
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -247,15 +240,8 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
247240
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
248241
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
249242
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
250-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
251-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
252-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
253-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
243+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v3.16b, v2.16b
254244
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
255-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
256-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
257-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
258-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
259245
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB6_1
260246
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
261247
; CHECK-NEWLOWERING-I8MM-NEXT: ret
@@ -306,19 +292,7 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
306292
;
307293
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
308294
; CHECK-NEWLOWERING-I8MM: // %bb.0:
309-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v1.8h, v1.8b, #0
310-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v2.8h, v2.8b, #0
311-
; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
312-
; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
313-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
314-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
315-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
316-
; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
317-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
318-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
319-
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
320-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
321-
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
295+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
322296
; CHECK-NEWLOWERING-I8MM-NEXT: ret
323297
%u.wide = zext <8 x i8> %u to <8 x i32>
324298
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -347,14 +321,7 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
347321
;
348322
; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
349323
; CHECK-NEWLOWERING-I8MM: // %bb.0:
350-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v3.8h, v1.8b, #0
351-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
352-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v1.8h, v1.16b, #0
353-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
354-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
355-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
356-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
357-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
324+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
358325
; CHECK-NEWLOWERING-I8MM-NEXT: ret
359326
%s.wide = sext <16 x i8> %u to <16 x i32>
360327
%u.wide = zext <16 x i8> %s to <16 x i32>
@@ -413,15 +380,8 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
413380
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
414381
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
415382
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
416-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
417-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
418-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
419-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
383+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
420384
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
421-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
422-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
423-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
424-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
425385
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB9_1
426386
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
427387
; CHECK-NEWLOWERING-I8MM-NEXT: ret
@@ -472,19 +432,7 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
472432
;
473433
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
474434
; CHECK-NEWLOWERING-I8MM: // %bb.0:
475-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v1.8h, v1.8b, #0
476-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v2.8h, v2.8b, #0
477-
; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
478-
; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
479-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
480-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
481-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
482-
; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
483-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
484-
; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
485-
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
486-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
487-
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
435+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
488436
; CHECK-NEWLOWERING-I8MM-NEXT: ret
489437
%u.wide = sext <8 x i8> %u to <8 x i32>
490438
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -614,26 +562,10 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
614562
;
615563
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_8to64:
616564
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
617-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
618-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
619-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
620-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
621-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.4s, v4.4h, #0
622-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v7.4s, v2.4h, #0
623-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v16.4s, v5.4h, #0
624-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v17.4s, v3.4h, #0
625-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.4s, v4.8h, #0
626-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.4s, v2.8h, #0
627-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v5.4s, v5.8h, #0
628-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.4s, v3.8h, #0
629-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
630-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
631-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
632-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
633-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
634-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
635-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
636-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
565+
; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
566+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
567+
; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
568+
; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
637569
; CHECK-NEWLOWERING-I8MM-NEXT: ret
638570
entry:
639571
%a.wide = zext <16 x i8> %a to <16 x i64>
@@ -679,26 +611,10 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
679611
;
680612
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_8to64:
681613
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
682-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
683-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
684-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
685-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
686-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v6.4s, v4.4h, #0
687-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.4s, v2.4h, #0
688-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v16.4s, v5.4h, #0
689-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v17.4s, v3.4h, #0
690-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v4.4s, v4.8h, #0
691-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.4s, v2.8h, #0
692-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v5.4s, v5.8h, #0
693-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.4s, v3.8h, #0
694-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
695-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
696-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
697-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
698-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
699-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
700-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
701-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
614+
; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
615+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
616+
; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
617+
; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
702618
; CHECK-NEWLOWERING-I8MM-NEXT: ret
703619
entry:
704620
%a.wide = sext <16 x i8> %a to <16 x i64>
@@ -1147,21 +1063,9 @@ define <4 x i32> @usdot_multiple_zext_users(ptr %p1, ptr %p2, ptr %p3) {
11471063
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
11481064
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q4, [x2, x8]
11491065
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
1150-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v2.8b, #0
1151-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.8h, v4.8b, #0
1152-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.8h, v3.8b, #0
1153-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
1154-
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.8h, v4.16b, #0
1155-
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
1066+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v4.16b, v2.16b
1067+
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v4.16b, v3.16b
11561068
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #1024
1157-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v6.4h
1158-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v7.4h, v6.4h
1159-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v5.8h, v6.8h
1160-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v7.8h, v6.8h
1161-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v4.4h
1162-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v3.4h, v4.4h
1163-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v4.8h
1164-
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v3.8h, v4.8h
11651069
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB28_1
11661070
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
11671071
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.4s, v1.4s, v0.4s

0 commit comments

Comments
 (0)