[VectorCombine] Fold binary op of reductions. #121567
Conversation
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms

Author: Mikhail Gudim (mgudim)

Changes

Replace binary op of two reductions with one reduction of the binary op applied to vectors. For example:
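```llvm
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = add i32 %v0_red, %v1_red
```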
gets transformed to:
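```llvm
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```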
Full diff: https://github.com/llvm/llvm-project/pull/121567.diff

5 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 7a184a19d7c54a..42e816d527fcff 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1516,6 +1516,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -2376,19 +2379,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
}
- auto m_AddRdx = [](Value *&Vec) {
- return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
- };
- Value *V0, *V1;
- if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
- V0->getType() == V1->getType()) {
- // Difference of sums is sum of differences:
- // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
- Value *Sub = Builder.CreateSub(V0, V1);
- Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
- {Sub->getType()}, {Sub});
- return replaceInstUsesWith(I, Rdx);
- }
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
if (Constant *C = dyn_cast<Constant>(Op0)) {
Value *X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e576eea4ca36a1..d9fcaf124d459b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2388,6 +2388,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -3588,6 +3591,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
@@ -4713,6 +4719,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 3a074ee70dc487..99301d3e991f55 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
+ Instruction *foldBinopOfReductions(BinaryOperator &Inst);
Instruction *foldVectorSelect(SelectInst &Sel);
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f85a3c93651353..98023c5eb89e42 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -205,6 +205,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *X = foldVectorBinop(I))
return X;
+ if (Instruction *X = foldBinopOfReductions(I))
+ return replaceInstUsesWith(I, X);
+
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 934156f04f7fdd..12c53e8a0869f7 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2296,6 +2296,58 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
return nullptr;
}
+static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
+ switch (Opc) {
+ default:
+ break;
+ case Instruction::Add:
+ return Intrinsic::vector_reduce_add;
+ case Instruction::Mul:
+ return Intrinsic::vector_reduce_mul;
+ case Instruction::And:
+ return Intrinsic::vector_reduce_and;
+ case Instruction::Or:
+ return Intrinsic::vector_reduce_or;
+ case Instruction::Xor:
+ return Intrinsic::vector_reduce_xor;
+ }
+ return Intrinsic::num_intrinsics;
+}
+
+Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
+ IntrinsicInst *II0 = dyn_cast<IntrinsicInst>(Inst.getOperand(0));
+ if (!II0)
+ return nullptr;
+ IntrinsicInst *II1 = dyn_cast<IntrinsicInst>(Inst.getOperand(1));
+ if (!II1)
+ return nullptr;
+
+ Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
+ Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+ if (BinOpOpc == Instruction::Sub)
+ ReductionIID = Intrinsic::vector_reduce_add;
+
+ if (ReductionIID == Intrinsic::num_intrinsics)
+ return nullptr;
+ if (II0->getIntrinsicID() != ReductionIID)
+ return nullptr;
+ if (II1->getIntrinsicID() != ReductionIID)
+ return nullptr;
+
+ Value *V0 = II0->getArgOperand(0);
+ Value *V1 = II1->getArgOperand(0);
+ Type *VTy = V0->getType();
+ if (V1->getType() != VTy)
+ return nullptr;
+
+ Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+ // if (auto *VectorInstBO = dyn_cast<BinaryOperator>(VectorBO))
+ // VectorInstBO->copyIRFlags(&Inst);
+
+ Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+ return Rdx;
+}
+
/// Try to narrow the width of a binop if at least 1 operand is an extend of
/// of a value. This requires a potentially expensive known bits check to make
/// sure the narrow op does not overflow.
Test is precommitted in: #121568
On the commented-out flag-copying code after `Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);`:
I don't think it is possible to propagate these flags.
You could preserve `disjoint`, but I don't think any others.
done, thanks.
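The updated hunk is not quoted here; a minimal sketch of what preserving `disjoint` could look like, using the `Inst`/`VectorBO` names from the diff above (assumption: only the `or` case carries a `disjoint` flag, and it transfers because disjoint scalar reduction results imply that the lanes are disjoint as well):

```cpp
// Only an `or` can be marked disjoint; nsw/nuw and other flags do not
// survive the reassociation, so they are deliberately not copied.
if (auto *NewOr = dyn_cast<PossiblyDisjointInst>(VectorBO))
  if (cast<PossiblyDisjointInst>(Inst).isDisjoint())
    NewOr->setIsDisjoint(true);
```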
On `IntrinsicInst *II0 = dyn_cast<IntrinsicInst>(Inst.getOperand(0));` in `foldBinopOfReductions`:
Missing one-use check
done
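The revised code isn't shown in this thread; one minimal way to add the requested guard to the version quoted above (assuming both reduction calls must become dead for the rewrite to pay off):

```cpp
// Bail out unless each reduction feeds only this binop, so that the two
// original vector.reduce calls can be erased after the fold.
if (!II0->hasOneUse() || !II1->hasOneUse())
  return nullptr;
```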
On the early-exit check `if (ReductionIID == Intrinsic::num_intrinsics) return nullptr;`:
You can drop this check, as no intrinsic will have ID `Intrinsic::num_intrinsics`.
I rewrote the code a little bit. From a purely logical point of view, yes, we can drop this check. But it is an early-exit check, and I also think the code looks much clearer with it.
(force-pushed from 71a72dd to 50c3ac4)
✅ With the latest revision this PR passed the C/C++ code formatter.
On the `case Instruction::Add:` entry of the `switch (Opc)` in `getReductionForBinop`:
You could also fold `or disjoint` with `vector_reduce_add`, but it's doubtful such a pattern shows up in practice.
I think `or disjoint` should get canonicalized first to `add`.
Hmm? We do `add` -> `or disjoint`, not the other way around.
right.
> You could also fold `or disjoint` with `vector_reduce_add`

What would be the benefit of doing this?
None I think, was just a note
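For reference, the pattern mentioned above would look roughly like this (hypothetical example, not handled by this patch): because `disjoint` guarantees the two reduction results share no set bits, the `or` is equivalent to an `add` and could in principle be folded the same way:

```llvm
define i32 @or_disjoint_of_add_reductions(<4 x i32> %v0, <4 x i32> %v1) {
  %r0 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v0)
  %r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v1)
  %res = or disjoint i32 %r0, %r1
  ret i32 %res
}
```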
(force-pushed from 5a57194 to d683355)
LGTM. Please wait for additional approval from other reviewers.
IIRC vertical vector ops are always cheaper than horizontal ones. It is suitable to perform this transformation in InstCombine.
On `if ((II->getIntrinsicID() == IID) && II->hasOneUse())`, the suggested change is to drop the redundant parentheses: `if (II->getIntrinsicID() == IID && II->hasOneUse())`.
done
(force-pushed from ead55cd to d4de22e)
@goldsteinn @nikic does this look ready to merge?
This will likely cause performance issues unless they are mitigated. There is a combine in the backend that does the same thing, but it is guarded by `shouldReassociateReduction`, and being in the backend allows the target to pattern-recognize the reduction before it gets reassociated.
Can you provide a test to demonstrate the performance issue? As I said before, vertical vector ops are always cheaper than horizontal ones. If not, we can move it into VectorCombine and use InstructionCost to see whether the transformation is a win.
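A minimal sketch of the cost check being suggested, assuming a VectorCombine-style helper where `TTI` is the TargetTransformInfo, `VecTy` the vector type, `BinOpOpc` the scalar binop opcode, and `ReductionOpc` the matching reduction opcode (names are illustrative, not from the patch):

```cpp
// Old: two reductions plus a scalar binop. New: one vector binop plus one
// reduction. Only fold when the new sequence is no more expensive.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost RedCost =
    TTI.getArithmeticReductionCost(ReductionOpc, VecTy, std::nullopt, CostKind);
InstructionCost OldCost =
    RedCost + RedCost +
    TTI.getArithmeticInstrCost(BinOpOpc, VecTy->getScalarType(), CostKind);
InstructionCost NewCost =
    RedCost + TTI.getArithmeticInstrCost(BinOpOpc, VecTy, CostKind);
bool Profitable = NewCost <= OldCost; // only fold when not more expensive
```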
Before this combine worked for ... I also find it surprising that this will cause a degradation on some targets. Can you please give an example? I also think ...

Related patch: https://reviews.llvm.org/D141870

reassociate is correct here. For example, this patch converts ...

got it, thanks

Perhaps we could have both InstCombine and DAGCombiner transformations guarded by different target hooks?
Yeah maybe. In my experience adds are much more common than the other reductions, and sub(reduce.add, reduce.add) would be less common than plain adds. Maybe it wasn't done for the more obvious ...

I believe that is mostly true, but some targets can have relatively efficient reductions. It looks like MVE will already convert back if needed: https://godbolt.org/z/6YGqde37s.

We might not have costs (beyond those that go via getExtendedReductionCost / getMulAccReductionCost), as multiple-instruction patterns can be awkward to cost-model well and it might not have come up before.

InstCombine, for better or worse, is considered a canonicalization pass that isn't controlled by target hooks. VectorCombine is the place where cost-modelled combines for vectors are usually done. We could also consider this to be the canonical pattern and undo the transform in the backend, if needed and if we could do so reliably.

Agree.
@davemgreen Sorry, I am still not sure how to proceed. (1) Should I close this PR? Or (2) keep this fold as the canonical form and undo it in the backend where needed? Or (3) move this to VectorCombine and guard it by a target hook? I prefer (2) because this looks more canonical to me.
@mgudim - I've had in mind a followup VectorCombine fold to this PR:

```llvm
define i32 @src(<4 x i32> %a0, <4 x i32> %a1) {
  %r0 = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %a0)
  %r1 = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %a1)
  %r = and i32 %r0, %r1
  ret i32 %r
}
define i32 @tgt1(<4 x i32> %a0, <4 x i32> %a1) {
  %a = and <4 x i32> %a0, %a1
  %r = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %a)
  ret i32 %r
}
define i32 @tgt2(<4 x i32> %a0, <4 x i32> %a1) {
  %a01 = shufflevector <4 x i32> %a0, <4 x i32> %a1, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
  %r = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> %a01)
  ret i32 %r
}
```

No objections if you want to move this PR to a cost-driven fold in VectorCombine as well.
I'll do that when I come back (in about a week).
(force-pushed from 86ae393 to a58cfeb)
@davemgreen
(force-pushed from ce85315 to 8afa5b8)
I looked at ... In all three cases the code looks quite similar. Also, it seems like I would have to repeat most of it in my patch too. Should we first come up with some API that simplifies this?
(force-pushed from 6b25c35 to 9a976f8)
Hi - Thanks for adding the cost model. It looks decent, and seems to work quite well for MVE. Unfortunately there are cases where it isn't working as well, at least one case with udot on AArch64, as we don't always generate the reduction pattern until instruction selection.
I was going to suggest I added a phase ordering test, but it looks like it needs LTO to go wrong as it only sees the vectors after loopvect+unroll+vector-combine, which do not usually happen in that order. I'll see if I can add one, but we might need to end up undoing this in the backend (or add a bailout from the target, or just leave it to the backend to do).
I think AArch64 / ARM may run into a similar problem with other combines, i.e. those combines may break the dot-product pattern for which there is native support. Now I think the best solution would be to treat the result of this transformation as the canonical form and undo it in the backend? Or maybe insert intrinsics early enough so that other combines don't break the pattern? Or both?
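For context, the AArch64 pattern at risk is roughly an extending multiply feeding an add reduction, which the backend matches to `udot`/`sdot` (illustrative sketch, not a test from the patch):

```llvm
define i32 @udot_style_reduction(<16 x i8> %a, <16 x i8> %b) {
  %ae = zext <16 x i8> %a to <16 x i32>
  %be = zext <16 x i8> %b to <16 x i32>
  %m = mul <16 x i32> %ae, %be
  %r = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %m)
  ret i32 %r
}
```

Merging two such reductions into one vector `add` of the multiplies before instruction selection can hide this shape from the backend.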
Still need to add cost-driven tests.
On:

```cpp
      CostOfRedOperand0 + CostOfRedOperand1 +
      TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
      TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
  // TODO: remove this
```
Clean up the TODO and dbg; we usually just do something like:

```cpp
LLVM_DEBUG(dbgs() << "Found mergeable reductions: " << I
                  << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                  << "\n");
```
Yes, I'll remove that. I need it while I am adding more tests.
It looks like @davemgreen is still thinking about what the right approach should be to avoid regressions.
Also, if I should continue with this, I still need advice on this: #121567 (comment)
done
I might be able to add a cost for extending add reductions for AArch64. Whilst not actually matching the correct instruction (it ends up using a udot), it should still be correct in terms of costs. I might be able to look into that tomorrow.
I added some quick cost-models for AArch64 extending reductions. I agree this isn't the most reliable thing, and having a more representative intrinsic might help. IMO this is better left for the backend, but not all architectures lower reduction nodes using sdag at the moment.
@davemgreen I still need your input on this: #121567 (comment) Should I basically repeat all the code from LoopVectorizer here? In general, I find it weird that opt has to know about some pattern which exists on AArch64 (dot product) and try to preserve it.
Replace binary op of two reductions with one reduction of the binary op applied to vectors. For example:

```llvm
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = add i32 %v0_red, %v1_red
```

gets transformed to:

```llvm
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```
(force-pushed from 8fad569 to 20b783d)
(force-pushed from 20b783d to bc1a198)
Added a test for ARM. @RKSimon do you think more tests are necessary? You mentioned adding a test for X86, but X86 doesn't implement ...
@davemgreen I added the cost model for the reduce.add(ext(mul(ext(A), ext(B))) pattern. There are other dot product patterns. I would be happy to work on them if only you told me what is your preferred solution to #121567 (comment). Can you please let me know? Or maybe you can address this later yourself?
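A rough sketch of how such an operand could be costed through the existing TTI hook rather than as a plain reduction (illustrative only; `IsUnsigned`, `ResTy`, and `SrcVecTy` are assumed to come from matching the `ext(mul(ext(A), ext(B)))` operand and are not names from the patch):

```cpp
// Cost the whole reduce.add(mul(ext(A), ext(B))) chain as one
// multiply-accumulate reduction (e.g. it maps to udot/sdot on AArch64).
InstructionCost MulAccCost =
    TTI.getMulAccReductionCost(IsUnsigned, ResTy, SrcVecTy, CostKind);
```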
> @davemgreen I added the cost model for the reduce.add(ext(mul(ext(A), ext(B))) pattern. There are other dot product patterns. I would be happy to work on them if only you told me what is your preferred solution to #121567 (comment) Can you please let me know? Or maybe you can address this later yourself?
Thanks. I think this might just always be a bit of a bad thing for Arm/AArch64 (for add's at least, the others are fine), and is more likely to make things worse than better. The backend already handles it in the cases we need it, and controlling it with a cost model is a little unreliable. But I ran some tests on this version and didn't see any problems in the tests I tried. The code you have looks decent and all the cases that were a problem before no longer are.
(I'm not sure of the most ideal way to structure it, if I was writing the whole compiler myself I would be tempted to add a smuladd/umuladd reduction intrinsic so that the cost modelling can be more accurate, but that sounds like quite a lot of work and might not even perfectly solve the problem in every case without effectively doing type/operation legalization earlier. Undoing arbitrarily reassociated/combined add trees can be difficult but is maybe better because it might capture more cases. "We can add that in the future if we need it" is probably the best way forward).
@davemgreen Thanks for testing the patch. So is this LGTM from you? )
LGTM
LLVM Buildbot has detected a new failure on builder ... Full details are available at: https://lab.llvm.org/buildbot/#/builders/65/builds/12655

Here is the relevant piece of the build log for reference: ...