-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AArch64] Add Neon USDOT support #143525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Add Neon USDOT support #143525
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Nicholas Guy (NickGuy-Arm) ChangesFull diff: https://github.com/llvm/llvm-project/pull/143525.diff 3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..766599d567efd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1464,6 +1464,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
+
+ if (Subtarget->hasMatMulInt8()) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v4i32,
+ MVT::v16i8, Legal);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i64,
+ MVT::v16i8, Custom);
+
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, MVT::v2i32,
+ MVT::v8i8, Legal);
+ }
}
} else /* !isNeonAvailable */ {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f5b66b75eb407..f90f12b5ac3c7 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1711,6 +1711,11 @@ multiclass SIMDSUDOTIndex {
defm SUDOTlane : SIMDSUDOTIndex;
+def : Pat<(v2i32 (partial_reduce_sumla v2i32:$Acc, v8i8:$LHS, v8i8:$RHS)),
+ (USDOTv8i8 $Acc, $RHS, $LHS)>;
+def : Pat<(v4i32 (partial_reduce_sumla v4i32:$Acc, v16i8:$LHS, v16i8:$RHS)),
+ (USDOTv16i8 $Acc, $RHS, $LHS)>;
+
}
// ARMv8.2-A FP16 Fused Multiply-Add Long
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d977d8fc9cf21..0c7b3c7d3c138 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -181,14 +181,7 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -247,15 +240,8 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v3.16b, v2.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB6_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: ret
@@ -306,19 +292,7 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -347,14 +321,7 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%s.wide = sext <16 x i8> %u to <16 x i32>
%u.wide = zext <16 x i8> %s to <16 x i32>
@@ -413,15 +380,8 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v2.16b, v3.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v4.4h, v5.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v4.8h, v5.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v2.4h, v3.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v2.8h, v3.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB9_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: ret
@@ -472,19 +432,7 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
; CHECK-NEWLOWERING-I8MM: // %bb.0:
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-I8MM-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v5.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: smull2 v1.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: ext v1.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
; CHECK-NEWLOWERING-I8MM-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -614,26 +562,10 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
;
; CHECK-NEWLOWERING-I8MM-LABEL: usdot_8to64:
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v7.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v16.4s, v5.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v17.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
@@ -679,26 +611,10 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
;
; CHECK-NEWLOWERING-I8MM-LABEL: sudot_8to64:
; CHECK-NEWLOWERING-I8MM: // %bb.0: // %entry
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.4s, v2.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v16.4s, v5.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v17.4s, v3.4h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v6.2s, v16.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v7.2s, v17.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v6.4s, v16.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v7.4s, v17.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-I8MM-NEXT: movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT: saddw2 v0.2d, v0.2d, v4.4s
; CHECK-NEWLOWERING-I8MM-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
@@ -1147,21 +1063,9 @@ define <4 x i32> @usdot_multiple_zext_users(ptr %p1, ptr %p2, ptr %p3) {
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q3, [x1, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: ldr q4, [x2, x8]
; CHECK-NEWLOWERING-I8MM-NEXT: add x8, x8, #16
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v5.8h, v2.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll v6.8h, v4.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll v7.8h, v3.8b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: ushll2 v4.8h, v4.16b, #0
-; CHECK-NEWLOWERING-I8MM-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v0.4s, v4.16b, v2.16b
+; CHECK-NEWLOWERING-I8MM-NEXT: usdot v1.4s, v4.16b, v3.16b
; CHECK-NEWLOWERING-I8MM-NEXT: cmp x8, #1024
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v5.4h, v6.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v7.4h, v6.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v5.8h, v6.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v7.8h, v6.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v0.4s, v2.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal v1.4s, v3.4h, v4.4h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v0.4s, v2.8h, v4.8h
-; CHECK-NEWLOWERING-I8MM-NEXT: smlal2 v1.4s, v3.8h, v4.8h
; CHECK-NEWLOWERING-I8MM-NEXT: b.ne .LBB28_1
; CHECK-NEWLOWERING-I8MM-NEXT: // %bb.2: // %end
; CHECK-NEWLOWERING-I8MM-NEXT: add v0.4s, v1.4s, v0.4s
|
@@ -1464,6 +1464,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, | |||
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal); | |||
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal); | |||
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom); | |||
|
|||
if (Subtarget->hasMatMulInt8()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we do a similar thing for fixed-length SVE, similar to what was done for #142032 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, PR opened at #143730
No description provided.