Skip to content

[AArch64] Add Neon USDOT support #143525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);

if (Subtarget->hasMatMulInt8()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do a similar thing for fixed-length SVE, similar to what was done for #142032 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, PR opened at #143730

setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
MVT::v16i8, Legal);
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
MVT::v16i8, Custom);

setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
MVT::v8i8, Legal);
}
}

} else /* !isNeonAvailable */ {
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1711,6 +1711,11 @@ multiclass SIMDSUDOTIndex {

defm SUDOTlane : SIMDSUDOTIndex;

def : Pat<(v2i32 (partial_reduce_sumla v2i32:$Acc, v8i8:$LHS, v8i8:$RHS)),
(USDOTv8i8 $Acc, $RHS, $LHS)>;
def : Pat<(v4i32 (partial_reduce_sumla v4i32:$Acc, v16i8:$LHS, v16i8:$RHS)),
(USDOTv16i8 $Acc, $RHS, $LHS)>;

}

// ARMv8.2-A FP16 Fused Multiply-Add Long
Expand Down
128 changes: 16 additions & 112 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,7 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
Expand Down Expand Up @@ -247,15 +240,8 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v3.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB6_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: ret
Expand Down Expand Up @@ -306,19 +292,7 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
Expand Down Expand Up @@ -347,14 +321,7 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v3.8h, v1.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%s.wide = sext <16 x i8> %u to <16 x i32>
%u.wide = zext <16 x i8> %s to <16 x i32>
Expand Down Expand Up @@ -413,15 +380,8 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB9_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: ret
Expand Down Expand Up @@ -472,19 +432,7 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
Expand Down Expand Up @@ -614,26 +562,10 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_8to64:
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.4s, v4.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v7.4s, v2.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v16.4s, v5.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v17.4s, v3.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.4s, v4.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v5.4s, v5.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.4s, v3.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
Expand Down Expand Up @@ -679,26 +611,10 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_8to64:
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v6.4s, v4.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.4s, v2.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v16.4s, v5.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v17.4s, v3.4h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v4.4s, v4.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v5.4s, v5.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
Expand Down Expand Up @@ -1147,21 +1063,9 @@ define <4 x i32> @usdot_multiple_zext_users(ptr %p1, ptr %p2, ptr %p3) {
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q4, [x2, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v2.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.8h, v4.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.8h, v3.8b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.8h, v4.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v4.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v4.16b, v3.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #1024
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v6.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v7.4h, v6.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v5.8h, v6.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v7.8h, v6.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v4.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v3.4h, v4.4h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v4.8h
; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v3.8h, v4.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB28_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.4s, v1.4s, v0.4s
Expand Down
Loading