[InstCombine] Fold vector.reduce.op(vector.reverse(X)) -> vector.reduce.op(X) #91743

Merged
6 commits merged from reduce_reverse into llvm:main on May 17, 2024

Conversation

david-arm (Contributor)

For all of the following reductions:

vector.reduce.or
vector.reduce.and
vector.reduce.xor
vector.reduce.add
vector.reduce.mul
vector.reduce.umin
vector.reduce.umax
vector.reduce.smin
vector.reduce.smax
vector.reduce.fmin
vector.reduce.fmax

if the input operand is the result of a vector.reverse, we can perform the reduction on the vector.reverse's input instead, since the result is the same: reversing only permutes the lanes, and none of these reductions depend on lane order. If reassociation is permitted, we can also perform the same fold for:

vector.reduce.fadd
vector.reduce.fmul

llvmbot (Member) commented May 10, 2024

@llvm/pr-subscribers-llvm-transforms

Author: David Sherwood (david-arm)

Full diff: https://github.com/llvm/llvm-project/pull/91743.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+66-1)
  • (modified) llvm/test/Transforms/InstCombine/vector-logical-reductions.ll (+72)
  • (modified) llvm/test/Transforms/InstCombine/vector-reductions.ll (+162)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d7433ad3599f9..fb2d62e23493a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3222,6 +3222,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // %res = cmp eq iReduxWidth %val, 11111
     Value *Arg = II->getArgOperand(0);
     Value *Vect;
+    // When doing a logical reduction of a reversed operand the result is
+    // identical to reducing the unreversed operand.
+    if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+      Value *Res = IID == Intrinsic::vector_reduce_or
+                       ? Builder.CreateOrReduce(Vect)
+                       : Builder.CreateAndReduce(Vect);
+      return replaceInstUsesWith(CI, Res);
+    }
     if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
       if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
         if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3253,6 +3261,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       // Trunc(ctpop(bitcast <n x i1> to in)).
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing an integer add reduction of a reversed operand the result
+      // is identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateAddReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3281,6 +3295,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_add(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a xor reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateXorReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3304,6 +3324,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_and(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a mul reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = Builder.CreateMulReduce(Vect);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3328,6 +3354,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   ?ext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a min/max reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = IID == Intrinsic::vector_reduce_umin
+                         ? Builder.CreateIntMinReduce(Vect, false)
+                         : Builder.CreateIntMaxReduce(Vect, false);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3363,6 +3397,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       //   zext(vector_reduce_{and,or}(<n x i1>))
       Value *Arg = II->getArgOperand(0);
       Value *Vect;
+      // When doing a min/max reduction of a reversed operand the result is
+      // identical to reducing the unreversed operand.
+      if (match(Arg, m_VecReverse(m_Value(Vect)))) {
+        Value *Res = IID == Intrinsic::vector_reduce_smin
+                         ? Builder.CreateIntMinReduce(Vect, true)
+                         : Builder.CreateIntMaxReduce(Vect, true);
+        return replaceInstUsesWith(CI, Res);
+      }
       if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
         if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
           if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3394,8 +3436,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
                                 : 0;
     Value *Arg = II->getArgOperand(ArgIdx);
     Value *V;
+
+    if (!CanBeReassociated)
+      break;
+
+    if (match(Arg, m_VecReverse(m_Value(V)))) {
+      Value *Res;
+      switch (IID) {
+      case Intrinsic::vector_reduce_fadd:
+        Res = Builder.CreateFAddReduce(II->getArgOperand(0), V);
+        break;
+      case Intrinsic::vector_reduce_fmul:
+        Res = Builder.CreateFMulReduce(II->getArgOperand(0), V);
+        break;
+      case Intrinsic::vector_reduce_fmin:
+        Res = Builder.CreateFPMinReduce(V);
+        break;
+      case Intrinsic::vector_reduce_fmax:
+        Res = Builder.CreateFPMaxReduce(V);
+        break;
+      }
+      return replaceInstUsesWith(CI, Res);
+    }
+
     ArrayRef<int> Mask;
-    if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated ||
+    if (!isa<FixedVectorType>(Arg->getType()) ||
         !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
         !cast<ShuffleVectorInst>(Arg)->isSingleSource())
       break;
diff --git a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
index 9bb307ebf71e8..da4a0ca754680 100644
--- a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
@@ -21,5 +21,77 @@ define i1 @reduction_logical_and(<4 x i1> %x) {
   ret i1 %r
 }
 
+define i1 @reduction_logical_or_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_or_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_or_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[RED:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.or.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_and_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[RED:%.*]] = icmp eq i2 [[TMP1]], -1
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
+; CHECK-NEXT:    [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %rev)
+  ret i1 %red
+}
+
+define i1 @reduction_logical_xor_reverse_v2i1(<2 x i1> %p) {
+; CHECK-LABEL: @reduction_logical_xor_reverse_v2i1(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
+; CHECK-NEXT:    [[TMP2:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP1]])
+; CHECK-NEXT:    [[RED:%.*]] = trunc i2 [[TMP2]] to i1
+; CHECK-NEXT:    ret i1 [[RED]]
+;
+  %rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
+  %red = call i1 @llvm.vector.reduce.xor.v2i1(<2 x i1> %rev)
+  ret i1 %red
+}
+
 declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.or.v2i1(<2 x i1>)
 declare i1 @llvm.vector.reduce.and.v4i1(<4 x i1>)
+declare i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.and.v2i1(<2 x i1>)
+declare i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1>)
+declare i1 @llvm.vector.reduce.xor.v2i1(<2 x i1>)
+declare <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1>)
+declare <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1>)
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 2614ffd386952..3e2a23a5ef64e 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -3,12 +3,29 @@
 
 declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)
 declare float @llvm.vector.reduce.fadd.f32.v8f32(float, <8 x float>)
+declare float @llvm.vector.reduce.fmul.f32.nxv4f32(float, <vscale x 4 x float>)
+declare float @llvm.vector.reduce.fmin.f32.v4f32(float, <4 x float>)
+declare float @llvm.vector.reduce.fmax.f32.nxv4f32(float, <vscale x 4 x float>)
 declare void @use_f32(float)
 
 declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32>)
 declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
 declare void @use_i32(i32)
 
+declare i32 @llvm.vector.reduce.mul.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32>)
+
+declare i32 @llvm.vector.reduce.smin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32>)
+declare i32 @llvm.vector.reduce.umin.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32>)
+
+declare <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32>)
+declare <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float>)
+declare <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32>)
+declare <4 x float> @llvm.vector.reverse.v4f32(<4 x float>)
+
 define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
 ; CHECK-LABEL: @diff_of_sums_v4f32(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fsub reassoc nsz <4 x float> [[V0:%.*]], [[V1:%.*]]
@@ -22,6 +39,71 @@ define float @diff_of_sums_v4f32(float %a0, <4 x float> %v0, float %a1, <4 x flo
   ret float %r
 }
 
+define float @reassoc_sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @reassoc_sum_of_reverse_v4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call reassoc float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+  ret float %red
+}
+
+define float @reassoc_mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @reassoc_mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 1.000000e+00, <vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call reassoc float @llvm.vector.reduce.fmul.nxv4f32(float 1.0, <vscale x 4 x float> %rev)
+  ret float %red
+}
+
+define float @fmax_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @fmax_of_reverse_v4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %rev)
+  ret float %red
+}
+
+define float @fmin_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @fmin_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> %rev)
+  ret float %red
+}
+
+; negative test - fadd cannot be folded with reverse due to lack of reassoc
+define float @sum_of_reverse_v4f32(<4 x float> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4f32(
+; CHECK-NEXT:    [[REV:%.*]] = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[REV]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <4 x float> @llvm.vector.reverse.v4f32(<4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fadd.v4f32(float zeroinitializer, <4 x float> %rev)
+  ret float %red
+}
+
+; negative test - fmul cannot be folded with reverse due to lack of reassoc
+define float @mul_reduction_of_reverse_nxv4f32(<vscale x 4 x float> %v0) {
+; CHECK-LABEL: @mul_reduction_of_reverse_nxv4f32(
+; CHECK-NEXT:    [[REV:%.*]] = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> [[V0:%.*]])
+; CHECK-NEXT:    [[RED:%.*]] = call float @llvm.vector.reduce.fmul.nxv4f32(float 0.000000e+00, <vscale x 4 x float> [[REV]])
+; CHECK-NEXT:    ret float [[RED]]
+;
+  %rev = call <vscale x 4 x float> @llvm.vector.reverse.nxv4f32(<vscale x 4 x float> %v0)
+  %red = call float @llvm.vector.reduce.fmul.nxv4f32(float zeroinitializer, <vscale x 4 x float> %rev)
+  ret float %red
+}
+
+
 ; negative test - fsub must allow reassociation
 
 define float @diff_of_sums_v4f32_fmf(float %a0, <4 x float> %v0, float %a1, <4 x float> %v1) {
@@ -98,6 +180,86 @@ define i32 @diff_of_sums_v4i32(<4 x i32> %v0, <4 x i32> %v1) {
   ret i32 %r
 }
 
+define i32 @sum_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @sum_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @sum_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @mul_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @mul_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.mul.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @smin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @smin_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.smin.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @smax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @smax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.smax.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @umin_reduce_of_reverse_v4i32(<4 x i32> %v0) {
+; CHECK-LABEL: @umin_reduce_of_reverse_v4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <4 x i32> @llvm.vector.reverse.v4i32(<4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.umin.v4i32(<4 x i32> %rev)
+  ret i32 %red
+}
+
+define i32 @umax_reduce_of_reverse_nxv4i32(<vscale x 4 x i32> %v0) {
+; CHECK-LABEL: @umax_reduce_of_reverse_nxv4i32(
+; CHECK-NEXT:    [[RED:%.*]] = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> [[V0:%.*]])
+; CHECK-NEXT:    ret i32 [[RED]]
+;
+  %rev = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> %v0)
+  %red = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> %rev)
+  ret i32 %red
+}
+
 ; negative test - extra uses could create extra instructions
 
 define i32 @diff_of_sums_v4i32_extra_use1(<4 x i32> %v0, <4 x i32> %v1) {

davemgreen added a commit to davemgreen/llvm-project that referenced this pull request May 10, 2024
As a small addition to llvm#91743, this uses copysign to produce the correct sign
for zero when converting frem to div/trunc/mul when we do not know that the
input is positive (or we care about sign bits). The copysign lets us get the
sign of zero correct.

In testing, the only case this produced different results than fmod was:
frem -inf, 4.0 -> nan vs -nan

Contributor

I think a generally easier way to implement this would be to create a helper in InstCombiner. The code for all of them can simply be replaceOperand of the reduce with the pre-reversed vec. That will also make it easier to expand (for example, to any shuffle that preserves all elements).
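
As a minimal sketch of that suggestion (illustrative only, not the code that eventually landed), assuming it sits inside InstCombinerImpl::visitCallInst and that ArgIdx selects the vector operand (0 for the integer reductions, 1 for vector_reduce_fadd/fmul):

Value *Vec;
if (match(II->getArgOperand(ArgIdx), m_VecReverse(m_Value(Vec))))
  // Keep the reduction and just swap in the unreversed vector; the dead
  // vector.reverse gets cleaned up afterwards. Returning the instruction
  // signals that something changed (see the fixpoint discussion below).
  return replaceOperand(*II, ArgIdx, Vec);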

Collaborator

The existing shuffle code is down below, and I agree it would be nice for them to be treated similarly.

Contributor Author

OK that makes sense. I'll have a go!

Contributor Author

There is a problem with this approach. In the Intrinsic::vector_reduce_or case, if I use the same approach as we do for fadd, etc., then on the first iteration we hit this code:

      replaceUse(II->getOperandUse(ArgIdx), V);
      return nullptr;

and then on the second iteration we still hit the existing code in Intrinsic::vector_reduce_or:

     if (!IsReverse && match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
       ...
          return replaceInstUsesWith(CI, Res);

because on the 2nd iteration the reduce_or input is no longer a reverse intrinsic. InstCombine quite rightly gets unhappy with this error:

LLVM ERROR: Instruction Combining did not reach a fixpoint after 1 iterations

Essentially, it looks like the only way I can stop the iterations is to return replaceInstUsesWith when we first spot the pattern (reduce.or(vector.reverse())), unless you have other suggestions that may help?

Contributor Author

If it helps I can post the broken patch on this PR so you can see the code and find out if I've done something wrong.

Contributor Author

Perhaps the solution is to return the original instruction instead of nullptr as this might break the loop?

Contributor Author

I've added the helper and the InstCombine tests all pass. I had to return the instruction in order to stop processing for all cases except fadd, fmul, fmin and fmax. I could probably now refactor all the reduction cases in the switch statement to remove the fallthroughs, but I thought it might be best to do that in a follow-on patch to make it easier to review.

Comment on lines 1438 to 1453
Instruction *InstCombinerImpl::simplifyReductionOfShuffle(IntrinsicInst *II) {
  Intrinsic::ID IID = II->getIntrinsicID();
  bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
                            IID != Intrinsic::vector_reduce_fmul) ||
                           II->hasAllowReassoc();

  if (!CanBeReassociated)
    return nullptr;

  const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
                           IID == Intrinsic::vector_reduce_fmul)
                              ? 1
                              : 0;
  Value *Arg = II->getArgOperand(ArgIdx);
  Value *V;

Collaborator

Just a thought, but perhaps this would be cleaner if you created something like simplifyReductionOperand(Value *Op, bool CanReorderLanes)? Depending on the reduction, you'd then end up with either:

if (auto *NewOp = simplifyReductionOperand(II->getArgOperand(0), true)) {
  replaceUse(II->getOperandUse(0), NewOp);
  return nullptr;
}

if (auto *NewOp = simplifyReductionOperand(II->getArgOperand(1), II->hasAllowReassoc())) {
  replaceUse(II->getOperandUse(1), NewOp);
  return nullptr;
}
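
For context, here is a rough sketch of what such a helper's body might contain (an illustration of the idea, not necessarily the committed code): it strips a vector.reverse, or a fixed-length single-source shuffle that uses every lane exactly once, whenever lane reordering is allowed.

static Value *simplifyReductionOperand(Value *Arg, bool CanReorderLanes) {
  if (!CanReorderLanes)
    return nullptr;

  Value *V;
  // A reverse only permutes lanes, so a lane-order-insensitive reduction can
  // simply ignore it.
  if (match(Arg, m_VecReverse(m_Value(V))))
    return V;

  // Likewise for a fixed-length shuffle that is a pure permutation of its
  // first operand: single source, no undef lanes, no repeated indices.
  ArrayRef<int> Mask;
  if (!isa<FixedVectorType>(Arg->getType()) ||
      !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
      !cast<ShuffleVectorInst>(Arg)->isSingleSource())
    return nullptr;

  int Sz = Mask.size();
  SmallBitVector UsedIndices(Sz);
  for (int Idx : Mask) {
    if (Idx < 0 || Idx >= Sz || UsedIndices.test(Idx))
      return nullptr;
    UsedIndices.set(Idx);
  }
  return UsedIndices.all() ? V : nullptr;
}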

Contributor Author

Good suggestion. Done!

Collaborator

That's not quite what I meant. I was more thinking simplifyReductionOperand would take the reduction's operand rather than the reduction itself. That way the reduction-specific parts remain at the call site.

Contributor Author

Ah sorry, I was focussed on the new flag and missed that bit.

Contributor Author

Done

david-arm added 2 commits May 15, 2024 13:54
[InstCombine] Fold vector.reduce.op(vector.reverse(X)) -> vector.reduce.op(X)

For all of the following reductions:

vector.reduce.or
vector.reduce.and
vector.reduce.xor
vector.reduce.add
vector.reduce.mul
vector.reduce.umin
vector.reduce.umax
vector.reduce.smin
vector.reduce.smax
vector.reduce.fmin
vector.reduce.fmax

if the input operand is the result of a vector.reverse then we
can perform a reduction on the vector.reverse input instead since
the answer is the same. If the reassociation is permitted we can
also do the same folds for these:

vector.reduce.fadd
vector.reduce.fmul
* Let all of the vector.reduce.op variants now call a new helper
called simplifyReductionOfShuffle, which deals with both
vector.reverse and shufflevector operations.
@david-arm
Contributor Author

Rebase

david-arm added 2 commits May 15, 2024 13:55
... and pass in a flag indicating whether we can reorder lanes.
* Pass reduction operand to simplifyReductionOperand instead and
return the simplified operand.
Comment on lines 1438 to 1439
Value *InstCombinerImpl::simplifyReductionOperand(Value *Arg,
                                                  bool CanReorderLanes) {
Collaborator

I'm happy either way but I suppose at this moment in time this could just be a static function.

Contributor Author

Done

@@ -3364,6 +3421,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// zext(vector_reduce_{and,or}(<n x i1>))
Value *Arg = II->getArgOperand(0);
Value *Vect;

if (Value *NewOp = simplifyReductionOperand(Arg, true)) {
Contributor

Nit: Here and above, style is /*CanReorderLanes=*/true

Contributor Author

Done!


// Can remove shuffle iff just shuffled elements, no repeats, undefs, or
// other changes.
return UsedIndices.all() ? V : nullptr;
Contributor

I think the shuffle check could be simplified with getShuffleDemandedElts.

Then the check would just be:

if(DemandedLHS.isAllOnes() && DemandedRHS.isZero()) {
  return LHS;
}
if(DemandedRHS.isAllOnes() && DemandedLHS.isZero()) {
  return RHS;
}
return nullptr;
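
For reference, a self-contained sketch of that idea (the function name here is made up, the getShuffleDemandedElts signature is quoted from memory of llvm/Analysis/VectorUtils.h and should be double-checked, and the equal-lane-count guard is an added assumption to keep the fold restricted to pure permutations):

static Value *matchPermutedReductionOperand(ShuffleVectorInst *Shuf) {
  auto *SrcTy = dyn_cast<FixedVectorType>(Shuf->getOperand(0)->getType());
  if (!SrcTy)
    return nullptr;

  ArrayRef<int> Mask = Shuf->getShuffleMask();
  unsigned SrcWidth = SrcTy->getNumElements();
  // The lane count must be preserved; otherwise elements are dropped or
  // duplicated and the reduction result may change.
  if (Mask.size() != SrcWidth)
    return nullptr;

  APInt DemandedElts = APInt::getAllOnes(Mask.size());
  APInt DemandedLHS, DemandedRHS;
  if (!getShuffleDemandedElts(SrcWidth, Mask, DemandedElts, DemandedLHS,
                              DemandedRHS))
    return nullptr;

  // With equal lane counts, demanding every lane of exactly one source and
  // none of the other means each lane is used exactly once, i.e. the shuffle
  // is a pure permutation of that source.
  if (DemandedLHS.isAllOnes() && DemandedRHS.isZero())
    return Shuf->getOperand(0);
  if (DemandedRHS.isAllOnes() && DemandedLHS.isZero())
    return Shuf->getOperand(1);
  return nullptr;
}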

Contributor Author

Hmm, this is not an NFC change. I am not quite sure how to use getShuffleDemandedElts, but I'm pretty sure your example above now permits use of a RHS whereas previously it didn't (see !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask)))). I would prefer to limit this to refactoring only, as it's not relevant to this patch. I can keep the if(DemandedLHS.isAllOnes() ... check.

Contributor Author

I had a look at using getShuffleDemandedElts and the code didn't immediately look much better as I still need to keep most of the existing checks, i.e. check for isSingleSource and so on. For now, I'd prefer to leave the code as it is, but I'm happy to revisit this in a follow-on patch. Just out of curiosity, have you seen any examples where we pass through the 2nd operand of the shufflevector? i.e. reduce_or(shufflevector(<4 x i32> %a, <4 x i32> %b, <4 x i32> <7, 5, 6, 4>)). I'd expect this to be canonicalised to reduce_or(shufflevector(<4 x i32> %a, <4 x i32> undef, <4 x i32> <3, 1, 2, 0>))

david-arm merged commit 0ad275c into llvm:main on May 17, 2024
3 of 4 checks passed
david-arm deleted the reduce_reverse branch on June 27, 2024 at 12:06