[AArch64] Improve codegen for some fixed-width partial reductions #126529

Merged: 4 commits merged into llvm:main from udot_loop on Feb 25, 2025

Conversation

david-arm
Contributor

This patch teaches optimizeExtendOrTruncateConversion to bail out
if the user of a zero-extend is a partial reduction intrinsic
that we know will get lowered efficiently to a udot instruction.

@llvmbot
Member

llvmbot commented Feb 10, 2025

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

Changes

This patch teaches optimizeExtendOrTruncateConversion to bail out
if the user of a zero-extend is a partial reduction intrinsic
that we know will get lowered efficiently to a udot instruction.


Full diff: https://github.com/llvm/llvm-project/pull/126529.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+6-1)
  • (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+242-3)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7f6da4fa38f90a7..a7e576dd2bb35c5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16866,9 +16866,14 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
     // mul(zext(i8), sext) can be transformed into smull(zext, sext) which
     // performs one extend implicitly. If DstWidth is at most 4 * SrcWidth, at
     // most one extra extend step is needed and using tbl is not profitable.
+    // Similarly, bail out if partial_reduce(acc, zext(i8)) can be lowered to a
+    // udot instruction.
     if (SrcWidth * 4 <= DstWidth && I->hasOneUser()) {
       auto *SingleUser = cast<Instruction>(*I->user_begin());
-      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))))
+      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))) ||
+          (isa<IntrinsicInst>(SingleUser) &&
+           !shouldExpandPartialReductionIntrinsic(
+               cast<IntrinsicInst>(SingleUser))))
         return false;
     }
 
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 9ece9edb843439d..b5f25e32439ee7e 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -26,6 +26,66 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
+; CHECK-DOT-LABEL: udot_in_loop:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT:    mov x8, xzr
+; CHECK-DOT-NEXT:  .LBB3_1: // %vector.body
+; CHECK-DOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT:    ldr q2, [x0, x8]
+; CHECK-DOT-NEXT:    ldr q3, [x1, x8]
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    add x8, x8, #16
+; CHECK-DOT-NEXT:    udot v1.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    cmp x8, #16
+; CHECK-DOT-NEXT:    b.ne .LBB3_1
+; CHECK-DOT-NEXT:  // %bb.2: // %end
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_in_loop:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT:    mov x8, xzr
+; CHECK-NODOT-NEXT:  .LBB3_1: // %vector.body
+; CHECK-NODOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT:    ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT:    ldr q2, [x1, x8]
+; CHECK-NODOT-NEXT:    add x8, x8, #16
+; CHECK-NODOT-NEXT:    cmp x8, #16
+; CHECK-NODOT-NEXT:    umull v3.8h, v0.8b, v2.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v0.16b, v2.16b
+; CHECK-NODOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v1.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v4.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v1.4s, v1.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v4.4s, v2.8h
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    b.ne .LBB3_1
+; CHECK-NODOT-NEXT:  // %bb.2: // %end
+; CHECK-NODOT-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-DOT-LABEL: udot_narrow:
 ; CHECK-DOT:       // %bb.0:
@@ -128,6 +188,68 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: usdot_in_loop:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT:    mov x8, xzr
+; CHECK-NOI8MM-NEXT:  .LBB10_1: // %vector.body
+; CHECK-NOI8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT:    ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT:    ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT:    add x8, x8, #16
+; CHECK-NOI8MM-NEXT:    cmp x8, #16
+; CHECK-NOI8MM-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT:    ushll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT:    smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT:    add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT:    b.ne .LBB10_1
+; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: usdot_in_loop:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    mov x8, xzr
+; CHECK-I8MM-NEXT:  .LBB10_1: // %vector.body
+; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT:    add x8, x8, #16
+; CHECK-I8MM-NEXT:    usdot v1.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT:    cmp x8, #16
+; CHECK-I8MM-NEXT:    b.ne .LBB10_1
+; CHECK-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-I8MM-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = sext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: usdot_narrow:
 ; CHECK-NOI8MM:       // %bb.0:
@@ -175,13 +297,75 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.4s, v2.16b, v1.16b
 ; CHECK-I8MM-NEXT:    ret
-  %u.wide = sext <16 x i8> %u to <16 x i32>
-  %s.wide = zext <16 x i8> %s to <16 x i32>
-  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %s.wide = sext <16 x i8> %u to <16 x i32>
+  %u.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %u.wide, %s.wide
   %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: sudot_in_loop:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT:    mov x8, xzr
+; CHECK-NOI8MM-NEXT:  .LBB13_1: // %vector.body
+; CHECK-NOI8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT:    ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT:    ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT:    add x8, x8, #16
+; CHECK-NOI8MM-NEXT:    cmp x8, #16
+; CHECK-NOI8MM-NEXT:    ushll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT:    sshll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT:    smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT:    add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT:    b.ne .LBB13_1
+; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: sudot_in_loop:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    mov x8, xzr
+; CHECK-I8MM-NEXT:  .LBB13_1: // %vector.body
+; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT:    add x8, x8, #16
+; CHECK-I8MM-NEXT:    usdot v1.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT:    cmp x8, #16
+; CHECK-I8MM-NEXT:    b.ne .LBB13_1
+; CHECK-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-I8MM-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = sext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: sudot_narrow:
 ; CHECK-NOI8MM:       // %bb.0:
@@ -389,6 +573,61 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_no_bin_op_in_loop(ptr %p){
+; CHECK-DOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    mov x8, xzr
+; CHECK-DOT-NEXT:  .LBB20_1: // %vector.body
+; CHECK-DOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT:    ldr q3, [x0, x8]
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    add x8, x8, #16
+; CHECK-DOT-NEXT:    cmp x8, #16
+; CHECK-DOT-NEXT:    udot v1.4s, v3.16b, v2.16b
+; CHECK-DOT-NEXT:    b.ne .LBB20_1
+; CHECK-DOT-NEXT:  // %bb.2: // %end
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT:    mov x8, xzr
+; CHECK-NODOT-NEXT:  .LBB20_1: // %vector.body
+; CHECK-NODOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT:    ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT:    add x8, x8, #16
+; CHECK-NODOT-NEXT:    cmp x8, #16
+; CHECK-NODOT-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.8h, v0.16b, #0
+; CHECK-NODOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v1.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v4.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    uaddw2 v1.4s, v1.4s, v2.8h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v4.4s, v3.8h
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    b.ne .LBB20_1
+; CHECK-NODOT-NEXT:  // %bb.2: // %end
+; CHECK-NODOT-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep = getelementptr i8, ptr %p, i64 %index
+  %load = load <16 x i8>, ptr %gep, align 16
+  %load.wide = zext <16 x i8> %load to <16 x i32>
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %load.wide)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
 ; CHECK-DOT-LABEL: sdot_no_bin_op:
 ; CHECK-DOT:       // %bb.0:

-      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))))
+      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))) ||
+          (isa<IntrinsicInst>(SingleUser) &&
+           !shouldExpandPartialReductionIntrinsic(

Contributor Author

Currently shouldExpandPartialReductionIntrinsic does not check whether the target actually has support for the udot/sdot, but the loop vectoriser should not be generating partial reduction intrinsic calls in that case.
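(For illustration only: if such a target-feature guard were ever added to the hook, it would presumably be a check along these lines. This is a hypothetical sketch, not part of this patch.)

  // Hypothetical guard inside shouldExpandPartialReductionIntrinsic: without
  // the dot-product feature there is no udot/sdot to lower to, so the
  // intrinsic would have to be expanded.
  if (!Subtarget->hasDotProd())
    return true;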

Collaborator

nit: This is missing a check that the IntrinsicID is a partial reduction intrinsic. I know this is checked in shouldExpandPartialReductionIntrinsic, but I think that function should assert that the intrinsic is a partial reduction intrinsic, rather than returning false when it isn't (or alternatively, just be passed in the types rather than the IntrinsicInst*, so that it can make its judgement based on the types).
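For illustration, a rough sketch of the shape being suggested here; the intrinsic-ID check and the surrounding code are assumptions for the sketch, not the committed implementation:

  // Inside shouldExpandPartialReductionIntrinsic, assert the precondition
  // rather than returning false for unrelated intrinsics.
  assert(I->getIntrinsicID() ==
             Intrinsic::experimental_vector_partial_reduce_add &&
         "Expected a partial reduction intrinsic");

  // ...which means the caller in optimizeExtendOrTruncateConversion would
  // check the intrinsic ID itself before invoking the hook.
  if (auto *II = dyn_cast<IntrinsicInst>(SingleUser))
    if (II->getIntrinsicID() ==
            Intrinsic::experimental_vector_partial_reduce_add &&
        !shouldExpandPartialReductionIntrinsic(II))
      return false;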

Contributor Author

Yep, that's a good suggestion. I've added an assert to shouldExpandPartialReductionIntrinsic for now, as I'd prefer not to change the hook in this patch if possible. Also, I can see a possible argument in future for the hook needing the instruction to provide context.

Also, there was a problem with my previous version anyway because the extended value could have been used for the accumulator, whereas in this latest version I'm explicitly checking that it's the second argument.
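A minimal sketch of that operand-position check, assuming the intrinsic operands are (accumulator, input); illustrative only, the merged code may differ in detail:

  // Bail out of the tbl transform only when the zero-extend feeds the
  // reduced input (operand 1) of partial_reduce_add(acc, input); if it
  // merely feeds the accumulator (operand 0), the udot lowering does not
  // absorb the extend.
  if (auto *II = dyn_cast<IntrinsicInst>(SingleUser))
    if (II->getIntrinsicID() ==
            Intrinsic::experimental_vector_partial_reduce_add &&
        II->getArgOperand(1) == I &&
        !shouldExpandPartialReductionIntrinsic(II))
      return false;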

-  %u.wide = sext <16 x i8> %u to <16 x i32>
-  %s.wide = zext <16 x i8> %s to <16 x i32>
-  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %s.wide = sext <16 x i8> %u to <16 x i32>

Contributor Author

I noticed the variable names were incorrect so I've corrected them while I'm here!

@@ -26,6 +26,66 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){

Collaborator

Why is the loop required here? From what I can see, optimizeExtendOrTruncateConversion is also called for blocks that are not in loops?

Contributor Author

I think that AArch64TargetLowering::optimizeExtendOrTruncateConversion bails out if the instruction does not live in a loop header, i.e.

  // Try to optimize conversions using tbl. This requires materializing constant
  // index vectors, which can increase code size and add loads. Skip the
  // transform unless the conversion is in a loop block guaranteed to execute
  // and we are not optimizing for size.
  Function *F = I->getParent()->getParent();
  if (!L || L->getHeader() != I->getParent() || F->hasMinSize() ||
      F->hasOptSize())
    return false;

Collaborator

Just a question and not something I expect you to change in this patch, but should the code you're changing be done regardless of whether it's in a loop? The smull(zext, sext) case seems like it would always be a win. Perhaps the same is true for the partial.reduce case: even though using an sdot/udot may require materialising a splat(1) vector, that constant should be cheap to materialise (and easy to CSE if there are multiple uses).

This patch teaches optimizeExtendOrTruncateConversion to bail out
if the user of a zero-extend is a partial reduction intrinsic
that we know will get lowered efficiently to a udot instruction.
@david-arm
Contributor Author

Gentle ping. :)

@davemgreen (Collaborator) left a comment

I think this is OK if there are no other comments. LGTM

@david-arm david-arm merged commit 85cf958 into llvm:main Feb 25, 2025
11 checks passed
@david-arm david-arm deleted the udot_loop branch February 27, 2025 13:02