-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AArch64] Add fixed-length SVE USDOT support #143730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Nicholas Guy (NickGuy-Arm) ChangesFull diff: https://github.com/llvm/llvm-project/pull/143730.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 766599d567efd..03f381f9d7a93 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
setPartialReduceMLAAction(MLAOps, VT,
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
}
+
+ if (Subtarget->hasMatMulInt8()) {
+ if (VT.getVectorElementType() == MVT::i32)
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+ else if (VT.getVectorElementType() == MVT::i64)
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+ }
}
// Lower fixed length vector operations to scalable equivalents.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 79d766d1b9908..81ed3e73481f8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
target triple = "aarch64"
@@ -407,6 +407,46 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
ret <4 x i32> %partial.reduce
}
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldr q0, [x0]
+; COMMON-NEXT: ldr q1, [x1]
+; COMMON-NEXT: ldr q2, [x2]
+; COMMON-NEXT: usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ptrue p0.s, vl4
+; SME-NEXT: ldr q2, [x0]
+; SME-NEXT: mov w8, #4 // =0x4
+; SME-NEXT: ld1b { z0.s }, p0/z, [x1]
+; SME-NEXT: ld1sb { z1.s }, p0/z, [x2]
+; SME-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT: mov w8, #8 // =0x8
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT: mov w8, #12 // =0xc
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: ret
+ %acc = load <4 x i32>, ptr %accptr
+ %u = load <16 x i8>, ptr %uptr
+ %s = load <16 x i8>, ptr %sptr
+ %u.wide = zext <16 x i8> %u to <16 x i32>
+ %s.wide = sext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
;
; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +478,67 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
ret <8 x i32> %partial.reduce
}
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON: // %bb.0:
+; COMMON-NEXT: ldp q0, q1, [x0]
+; COMMON-NEXT: ldp q3, q2, [x1]
+; COMMON-NEXT: ldp q5, q4, [x2]
+; COMMON-NEXT: usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT: usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ptrue p0.s, vl4
+; SME-NEXT: mov w8, #16 // =0x10
+; SME-NEXT: mov w9, #4 // =0x4
+; SME-NEXT: ldp q5, q4, [x0]
+; SME-NEXT: ld1b { z0.s }, p0/z, [x1, x8]
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT: ld1sb { z3.s }, p0/z, [x2]
+; SME-NEXT: mov w8, #20 // =0x14
+; SME-NEXT: ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT: mad z0.s, p0/m, z2.s, z4.s
+; SME-NEXT: ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT: mad z1.s, p0/m, z3.s, z5.s
+; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT: mov w8, #24 // =0x18
+; SME-NEXT: mov w9, #8 // =0x8
+; SME-NEXT: ld1b { z5.s }, p0/z, [x1, x8]
+; SME-NEXT: mla z0.s, p0/m, z3.s, z6.s
+; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT: mov w8, #28 // =0x1c
+; SME-NEXT: mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT: ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT: mov w9, #12 // =0xc
+; SME-NEXT: ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT: mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT: movprfx z2, z0
+; SME-NEXT: mla z2.s, p0/m, z3.s, z5.s
+; SME-NEXT: ld1b { z0.s }, p0/z, [x1, x9]
+; SME-NEXT: ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT: ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT: mad z0.s, p0/m, z4.s, z1.s
+; SME-NEXT: movprfx z1, z2
+; SME-NEXT: mla z1.s, p0/m, z3.s, z6.s
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32>
+ %s.wide = sext <32 x i8> %s to <32 x i32>
+ %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
;
;
@@ -483,6 +584,61 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
ret <8 x i32> %partial.reduce
}
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON: // %bb.0:
+; NEON-NEXT: ldp q0, q1, [x0]
+; NEON-NEXT: ldp q3, q2, [x1]
+; NEON-NEXT: ldp q5, q4, [x2]
+; NEON-NEXT: usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT: usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT: ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE: // %bb.0:
+; SVE-NEXT: ldr z0, [x0]
+; SVE-NEXT: ldr z1, [x1]
+; SVE-NEXT: ldr z2, [x2]
+; SVE-NEXT: usdot z0.s, z1.b, z2.b
+; SVE-NEXT: mov z1.d, z0.d
+; SVE-NEXT: ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT: ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME: // %bb.0:
+; SME-NEXT: ptrue p0.s
+; SME-NEXT: ldr z0, [x0]
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2]
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #1, mul vl]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #2, mul vl]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: ld1b { z1.s }, p0/z, [x1, #3, mul vl]
+; SME-NEXT: ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
+; SME-NEXT: mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT: mov z1.d, z0.d
+; SME-NEXT: ext z1.b, z1.b, z0.b, #16
+; SME-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT: // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT: ret
+ %acc = load <8 x i32>, ptr %accptr
+ %u = load <32 x i8>, ptr %uptr
+ %s = load <32 x i8>, ptr %sptr
+ %u.wide = zext <32 x i8> %u to <32 x i32>
+ %s.wide = sext <32 x i8> %s to <32 x i32>
+ %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+ ret <8 x i32> %partial.reduce
+}
+
;
; Four-way dot (i16 -> i64)
;
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it looks good to me. I've just added the request for two more tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
No description provided.