[AArch64] Add fixed-length SVE USDOT support #143730

Merged: 4 commits into llvm:main on Jun 13, 2025

Conversation

NickGuy-Arm (Contributor)

No description provided.

llvmbot (Member) commented Jun 11, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/143730.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+7)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll (+158-2)
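
Since no description was provided, a brief summary grounded in the diff: when +i8mm (FEAT_I8MM) is available, the backend now marks ISD::PARTIAL_REDUCE_SUMLA as Custom for fixed-length i32/i64 result types, so mixed-sign (zext x sext) i8 partial-reduction dot products can select USDOT. A minimal IR sketch of the pattern, distilled from the tests added below (the function name is illustrative):

define <4 x i32> @usdot_sketch(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
  ; Unsigned-by-signed widening multiply feeding a 4-way partial reduction;
  ; with +i8mm this is the shape that can now lower to USDOT.
  %u.wide = zext <16 x i8> %u to <16 x i32>
  %s.wide = sext <16 x i8> %s to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %partial.reduce
}
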
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 766599d567efd..03f381f9d7a93 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2272,6 +2272,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
       setPartialReduceMLAAction(MLAOps, VT,
                                 MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
     }
+
+    if (Subtarget->hasMatMulInt8()) {
+      if (VT.getVectorElementType() == MVT::i32)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
+      else if (VT.getVectorElementType() == MVT::i64)
+        setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT, MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
+    }
   }
 
   // Lower fixed length vector operations to scalable equivalents.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
index 79d766d1b9908..81ed3e73481f8 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-partial-reduce.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
-; RUN: llc -mattr=+sve,+dotprod -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
+; RUN: llc -mattr=+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,NEON
+; RUN: llc -mattr=+sve,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes=true < %s | FileCheck %s --check-prefixes=COMMON,SVE
 ; RUN: llc -mattr=+sme -aarch64-enable-partial-reduce-nodes=true -force-streaming < %s | FileCheck %s --check-prefix=SME
 
 target triple = "aarch64"
@@ -407,6 +407,46 @@ define <4 x i32> @four_way_i8_i32_vl128(ptr %accptr, ptr %uptr, ptr %sptr) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @four_way_i8_i32_vl128_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+; COMMON-LABEL: four_way_i8_i32_vl128_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldr q0, [x0]
+; COMMON-NEXT:    ldr q1, [x1]
+; COMMON-NEXT:    ldr q2, [x2]
+; COMMON-NEXT:    usdot v0.4s, v1.16b, v2.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    ldr q2, [x0]
+; SME-NEXT:    mov w8, #4 // =0x4
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z1.s }, p0/z, [x2]
+; SME-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #8 // =0x8
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #12 // =0xc
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    ret
+  %acc = load <4 x i32>, ptr %accptr
+  %u = load <16 x i8>, ptr %uptr
+  %s = load <16 x i8>, ptr %sptr
+  %u.wide = zext <16 x i8> %u to <16 x i32>
+  %s.wide = sext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr %sptr) {
 ;
 ; COMMON-LABEL: four_way_i8_i32_vl128_double_width:
@@ -438,6 +478,67 @@ define <8 x i32> @four_way_i8_i32_vl128_double_width(ptr %accptr, ptr %uptr, ptr
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl128_double_width_usdot(ptr %accptr, ptr %uptr, ptr %sptr) {
+;
+; COMMON-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; COMMON:       // %bb.0:
+; COMMON-NEXT:    ldp q0, q1, [x0]
+; COMMON-NEXT:    ldp q3, q2, [x1]
+; COMMON-NEXT:    ldp q5, q4, [x2]
+; COMMON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; COMMON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; COMMON-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl128_double_width_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s, vl4
+; SME-NEXT:    mov w8, #16 // =0x10
+; SME-NEXT:    mov w9, #4 // =0x4
+; SME-NEXT:    ldp q5, q4, [x0]
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x8]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2]
+; SME-NEXT:    mov w8, #20 // =0x14
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mad z0.s, p0/m, z2.s, z4.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z1.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #24 // =0x18
+; SME-NEXT:    mov w9, #8 // =0x8
+; SME-NEXT:    ld1b { z5.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z0.s, p0/m, z3.s, z6.s
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    mov w8, #28 // =0x1c
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    ld1b { z2.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mov w9, #12 // =0xc
+; SME-NEXT:    ld1b { z6.s }, p0/z, [x1, x8]
+; SME-NEXT:    mla z1.s, p0/m, z4.s, z2.s
+; SME-NEXT:    movprfx z2, z0
+; SME-NEXT:    mla z2.s, p0/m, z3.s, z5.s
+; SME-NEXT:    ld1b { z0.s }, p0/z, [x1, x9]
+; SME-NEXT:    ld1sb { z3.s }, p0/z, [x2, x8]
+; SME-NEXT:    ld1sb { z4.s }, p0/z, [x2, x9]
+; SME-NEXT:    mad z0.s, p0/m, z4.s, z1.s
+; SME-NEXT:    movprfx z1, z2
+; SME-NEXT:    mla z1.s, p0/m, z3.s, z6.s
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
 ;
 ;
@@ -483,6 +584,61 @@ define <8 x i32> @four_way_i8_i32_vl256(ptr %accptr, ptr %uptr, ptr %sptr) vscal
   ret <8 x i32> %partial.reduce
 }
 
+define <8 x i32> @four_way_i8_i32_vl256_usdot(ptr %accptr, ptr %uptr, ptr %sptr) vscale_range(2,2) {
+;
+;
+; NEON-LABEL: four_way_i8_i32_vl256_usdot:
+; NEON:       // %bb.0:
+; NEON-NEXT:    ldp q0, q1, [x0]
+; NEON-NEXT:    ldp q3, q2, [x1]
+; NEON-NEXT:    ldp q5, q4, [x2]
+; NEON-NEXT:    usdot v0.4s, v3.16b, v5.16b
+; NEON-NEXT:    usdot v1.4s, v2.16b, v4.16b
+; NEON-NEXT:    ret
+;
+; SVE-LABEL: four_way_i8_i32_vl256_usdot:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ldr z0, [x0]
+; SVE-NEXT:    ldr z1, [x1]
+; SVE-NEXT:    ldr z2, [x2]
+; SVE-NEXT:    usdot z0.s, z1.b, z2.b
+; SVE-NEXT:    mov z1.d, z0.d
+; SVE-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SVE-NEXT:    ret
+;
+; SME-LABEL: four_way_i8_i32_vl256_usdot:
+; SME:       // %bb.0:
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    ldr z0, [x0]
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #1, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #1, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #2, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #2, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    ld1b { z1.s }, p0/z, [x1, #3, mul vl]
+; SME-NEXT:    ld1sb { z2.s }, p0/z, [x2, #3, mul vl]
+; SME-NEXT:    mla z0.s, p0/m, z2.s, z1.s
+; SME-NEXT:    mov z1.d, z0.d
+; SME-NEXT:    ext z1.b, z1.b, z0.b, #16
+; SME-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SME-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; SME-NEXT:    ret
+  %acc = load <8 x i32>, ptr %accptr
+  %u = load <32 x i8>, ptr %uptr
+  %s = load <32 x i8>, ptr %sptr
+  %u.wide = zext <32 x i8> %u to <32 x i32>
+  %s.wide = sext <32 x i8> %s to <32 x i32>
+  %mult = mul nuw nsw <32 x i32> %s.wide, %u.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <32 x i32> %mult)
+  ret <8 x i32> %partial.reduce
+}
+
 ;
 ; Four-way dot (i16 -> i64)
 ;

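The ISelLowering change also registers an i64 combination (e.g. VT = v2i64 with a v16i8 input, i.e. an eight-way dot product per lane) that the tests above do not cover. A hypothetical IR counterpart, purely as a sketch (names illustrative; codegen for this case is not verified in this PR):

define <2 x i64> @sumla_i64_sketch(<2 x i64> %acc, <16 x i8> %u, <16 x i8> %s) {
  ; Reflects the NumElts * 8 ratio from the lowering change: 16 i8
  ; elements reduce into 2 i64 lanes, eight products per lane.
  %u.wide = zext <16 x i8> %u to <16 x i64>
  %s.wide = sext <16 x i8> %s to <16 x i64>
  %mult = mul nuw nsw <16 x i64> %s.wide, %u.wide
  %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add(<2 x i64> %acc, <16 x i64> %mult)
  ret <2 x i64> %partial.reduce
}
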
github-actions bot commented Jun 11, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@sdesmalen-arm (Collaborator) left a comment

Thanks, it looks good to me. I've just added a request for two more tests.

@sdesmalen-arm (Collaborator) left a comment

LGTM!

NickGuy-Arm merged commit 3ea45a6 into llvm:main on Jun 13, 2025; 7 checks passed.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025