Skip to content

Commit d683355

Browse files
author
Mikhail Gudim
committed
[InstCombine] Fold binary op of reductions.
Replace a binary op of two reductions with one reduction of the binary op applied to the vectors. For example: ``` %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0) %v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1) %res = add i32 %v0_red, %v1_red ``` gets transformed to: ``` %1 = add <16 x i32> %v0, %v1 %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1) ```
1 parent bfa711a commit d683355

File tree

5 files changed

+75
-13
lines changed

5 files changed

+75
-13
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15281528
if (Instruction *X = foldVectorBinop(I))
15291529
return X;
15301530

1531+
if (Instruction *X = foldBinopOfReductions(I))
1532+
return replaceInstUsesWith(I, X);
1533+
15311534
if (Instruction *Phi = foldBinopWithPhiOperands(I))
15321535
return Phi;
15331536

@@ -2378,19 +2381,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23782381
}
23792382
}
23802383

2381-
auto m_AddRdx = [](Value *&Vec) {
2382-
return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
2383-
};
2384-
Value *V0, *V1;
2385-
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
2386-
V0->getType() == V1->getType()) {
2387-
// Difference of sums is sum of differences:
2388-
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
2389-
Value *Sub = Builder.CreateSub(V0, V1);
2390-
Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
2391-
{Sub->getType()}, {Sub});
2392-
return replaceInstUsesWith(I, Rdx);
2393-
}
2384+
if (Instruction *X = foldBinopOfReductions(I))
2385+
return replaceInstUsesWith(I, X);
23942386

23952387
if (Constant *C = dyn_cast<Constant>(Op0)) {
23962388
Value *X;

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
23992399
if (Instruction *X = foldVectorBinop(I))
24002400
return X;
24012401

2402+
if (Instruction *X = foldBinopOfReductions(I))
2403+
return replaceInstUsesWith(I, X);
2404+
24022405
if (Instruction *Phi = foldBinopWithPhiOperands(I))
24032406
return Phi;
24042407

@@ -3585,6 +3588,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
35853588
if (Instruction *X = foldVectorBinop(I))
35863589
return X;
35873590

3591+
if (Instruction *X = foldBinopOfReductions(I))
3592+
return replaceInstUsesWith(I, X);
3593+
35883594
if (Instruction *Phi = foldBinopWithPhiOperands(I))
35893595
return Phi;
35903596

@@ -4696,6 +4702,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
46964702
if (Instruction *X = foldVectorBinop(I))
46974703
return X;
46984704

4705+
if (Instruction *X = foldBinopOfReductions(I))
4706+
return replaceInstUsesWith(I, X);
4707+
46994708
if (Instruction *Phi = foldBinopWithPhiOperands(I))
47004709
return Phi;
47014710

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
597597

598598
/// Canonicalize the position of binops relative to shufflevector.
599599
Instruction *foldVectorBinop(BinaryOperator &Inst);
600+
Instruction *foldBinopOfReductions(BinaryOperator &Inst);
600601
Instruction *foldVectorSelect(SelectInst &Sel);
601602
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
602603

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
201201
if (Instruction *X = foldVectorBinop(I))
202202
return X;
203203

204+
if (Instruction *X = foldBinopOfReductions(I))
205+
return replaceInstUsesWith(I, X);
206+
204207
if (Instruction *Phi = foldBinopWithPhiOperands(I))
205208
return Phi;
206209

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,6 +2313,63 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23132313
return nullptr;
23142314
}
23152315

2316+
/// Map a scalar binary opcode to the vector-reduce intrinsic that combines the
/// lanes of a vector with that same operation.
///
/// \param Opc the binary opcode to look up.
/// \returns the matching vector_reduce_* intrinsic ID for Add, Mul, And, Or
/// and Xor; Intrinsic::not_intrinsic for every other opcode.
static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
  switch (Opc) {
  case Instruction::Add:
    return Intrinsic::vector_reduce_add;
  case Instruction::Mul:
    return Intrinsic::vector_reduce_mul;
  case Instruction::And:
    return Intrinsic::vector_reduce_and;
  case Instruction::Or:
    return Intrinsic::vector_reduce_or;
  case Instruction::Xor:
    return Intrinsic::vector_reduce_xor;
  default:
    // No reduction is listed for this opcode.
    return Intrinsic::not_intrinsic;
  }
}
2333+
2334+
/// Fold a scalar binop whose two operands are single-use vector reductions of
/// the same kind into one reduction of the vector binop:
///   binop(rdx(V0), rdx(V1)) --> rdx(binop(V0, V1))
/// Sub is paired with the add reduction:
///   add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
///
/// \param Inst the scalar binary operator to examine.
/// \returns the newly built reduction call, or nullptr when the opcode has no
/// associated reduction, either operand is not a single-use call to that
/// reduction, or the two reduced vectors have different types.
Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
  const Instruction::BinaryOps Opcode = Inst.getOpcode();

  // A difference of sums is the sum of differences, so sub reuses the add
  // reduction; every other opcode goes through the lookup table.
  const Intrinsic::ID RdxID = Opcode == Instruction::Sub
                                  ? Intrinsic::vector_reduce_add
                                  : getReductionForBinop(Opcode);
  if (RdxID == Intrinsic::not_intrinsic)
    return nullptr;

  // Yields the vector being reduced when Operand is a call to RdxID with
  // exactly one use; otherwise yields null.
  auto getReducedVector = [RdxID](Value *Operand) -> Value * {
    auto *Call = dyn_cast<IntrinsicInst>(Operand);
    if (!Call || Call->getIntrinsicID() != RdxID || !Call->hasOneUse())
      return nullptr;
    return Call->getArgOperand(0);
  };

  Value *LHSVec = getReducedVector(Inst.getOperand(0));
  if (!LHSVec)
    return nullptr;

  Value *RHSVec = getReducedVector(Inst.getOperand(1));
  if (!RHSVec)
    return nullptr;

  // Both reductions must consume the same vector type for the vector binop to
  // be well-typed.
  Type *VecTy = LHSVec->getType();
  if (VecTy != RHSVec->getType())
    return nullptr;

  Value *VecBinop = Builder.CreateBinOp(Opcode, LHSVec, RHSVec);

  // Carry the disjoint flag over when both the scalar op and the freshly
  // built vector op are possibly-disjoint instructions.
  if (auto *ScalarPD = dyn_cast<PossiblyDisjointInst>(&Inst))
    if (auto *VecPD = dyn_cast<PossiblyDisjointInst>(VecBinop))
      VecPD->setIsDisjoint(ScalarPD->isDisjoint());

  return Builder.CreateIntrinsic(RdxID, {VecTy}, {VecBinop});
}
2372+
23162373
/// Try to narrow the width of a binop if at least 1 operand is an extend of
23172374
/// of a value. This requires a potentially expensive known bits check to make
23182375
/// sure the narrow op does not overflow.

0 commit comments

Comments
 (0)