Skip to content

Commit 71a72dd

Browse files
author
Mikhail Gudim
committed
[InstCombine] Fold binary op of reductions.
Replace binary of of two reductions with one reduction of the binary op applied to vectors. For example: ``` %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0) %v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1) %res = add i32 %v0_red, %v1_red ``` gets transformed to: ``` %1 = add <16 x i32> %v0, %v1 %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1) ```
1 parent 207e485 commit 71a72dd

File tree

5 files changed

+70
-13
lines changed

5 files changed

+70
-13
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15161516
if (Instruction *X = foldVectorBinop(I))
15171517
return X;
15181518

1519+
if (Instruction *X = foldBinopOfReductions(I))
1520+
return replaceInstUsesWith(I, X);
1521+
15191522
if (Instruction *Phi = foldBinopWithPhiOperands(I))
15201523
return Phi;
15211524

@@ -2376,19 +2379,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23762379
}
23772380
}
23782381

2379-
auto m_AddRdx = [](Value *&Vec) {
2380-
return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
2381-
};
2382-
Value *V0, *V1;
2383-
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
2384-
V0->getType() == V1->getType()) {
2385-
// Difference of sums is sum of differences:
2386-
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
2387-
Value *Sub = Builder.CreateSub(V0, V1);
2388-
Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
2389-
{Sub->getType()}, {Sub});
2390-
return replaceInstUsesWith(I, Rdx);
2391-
}
2382+
if (Instruction *X = foldBinopOfReductions(I))
2383+
return replaceInstUsesWith(I, X);
23922384

23932385
if (Constant *C = dyn_cast<Constant>(Op0)) {
23942386
Value *X;

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2388,6 +2388,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
23882388
if (Instruction *X = foldVectorBinop(I))
23892389
return X;
23902390

2391+
if (Instruction *X = foldBinopOfReductions(I))
2392+
return replaceInstUsesWith(I, X);
2393+
23912394
if (Instruction *Phi = foldBinopWithPhiOperands(I))
23922395
return Phi;
23932396

@@ -3588,6 +3591,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
35883591
if (Instruction *X = foldVectorBinop(I))
35893592
return X;
35903593

3594+
if (Instruction *X = foldBinopOfReductions(I))
3595+
return replaceInstUsesWith(I, X);
3596+
35913597
if (Instruction *Phi = foldBinopWithPhiOperands(I))
35923598
return Phi;
35933599

@@ -4713,6 +4719,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
47134719
if (Instruction *X = foldVectorBinop(I))
47144720
return X;
47154721

4722+
if (Instruction *X = foldBinopOfReductions(I))
4723+
return replaceInstUsesWith(I, X);
4724+
47164725
if (Instruction *Phi = foldBinopWithPhiOperands(I))
47174726
return Phi;
47184727

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
594594

595595
/// Canonicalize the position of binops relative to shufflevector.
596596
Instruction *foldVectorBinop(BinaryOperator &Inst);
597+
Instruction *foldBinopOfReductions(BinaryOperator &Inst);
597598
Instruction *foldVectorSelect(SelectInst &Sel);
598599
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
599600

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
205205
if (Instruction *X = foldVectorBinop(I))
206206
return X;
207207

208+
if (Instruction *X = foldBinopOfReductions(I))
209+
return replaceInstUsesWith(I, X);
210+
208211
if (Instruction *Phi = foldBinopWithPhiOperands(I))
209212
return Phi;
210213

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,6 +2296,58 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
22962296
return nullptr;
22972297
}
22982298

2299+
static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
2300+
switch (Opc) {
2301+
default:
2302+
break;
2303+
case Instruction::Add:
2304+
return Intrinsic::vector_reduce_add;
2305+
case Instruction::Mul:
2306+
return Intrinsic::vector_reduce_mul;
2307+
case Instruction::And:
2308+
return Intrinsic::vector_reduce_and;
2309+
case Instruction::Or:
2310+
return Intrinsic::vector_reduce_or;
2311+
case Instruction::Xor:
2312+
return Intrinsic::vector_reduce_xor;
2313+
}
2314+
return Intrinsic::num_intrinsics;
2315+
}
2316+
2317+
Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
2318+
IntrinsicInst *II0 = dyn_cast<IntrinsicInst>(Inst.getOperand(0));
2319+
if (!II0)
2320+
return nullptr;
2321+
IntrinsicInst *II1 = dyn_cast<IntrinsicInst>(Inst.getOperand(1));
2322+
if (!II1)
2323+
return nullptr;
2324+
2325+
Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
2326+
Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
2327+
if (BinOpOpc == Instruction::Sub)
2328+
ReductionIID = Intrinsic::vector_reduce_add;
2329+
2330+
if (ReductionIID == Intrinsic::num_intrinsics)
2331+
return nullptr;
2332+
if (II0->getIntrinsicID() != ReductionIID)
2333+
return nullptr;
2334+
if (II1->getIntrinsicID() != ReductionIID)
2335+
return nullptr;
2336+
2337+
Value *V0 = II0->getArgOperand(0);
2338+
Value *V1 = II1->getArgOperand(0);
2339+
Type *VTy = V0->getType();
2340+
if (V1->getType() != VTy)
2341+
return nullptr;
2342+
2343+
Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
2344+
// if (auto *VectorInstBO = dyn_cast<BinaryOperator>(VectorBO))
2345+
// VectorInstBO->copyIRFlags(&Inst);
2346+
2347+
Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
2348+
return Rdx;
2349+
}
2350+
22992351
/// Try to narrow the width of a binop if at least 1 operand is an extend of
23002352
/// of a value. This requires a potentially expensive known bits check to make
23012353
/// sure the narrow op does not overflow.

0 commit comments

Comments
 (0)