Commit d4de22e
Author: Mikhail Gudim
Parent: beba4b0

[InstCombine] Fold binary op of reductions.
Replace a binary op of two reductions with one reduction of the binary op applied to the vectors. For example:

```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = add i32 %v0_red, %v1_red
```

gets transformed to:

```
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```
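The same rewrite covers the difference-of-sums case: `visitSub` maps `sub` onto `llvm.vector.reduce.add`, replacing the ad-hoc `m_AddRdx` pattern previously in InstCombineAddSub.cpp (see the diff below). A minimal sketch of that fold, with illustrative value names:

```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
%res = sub i32 %v0_red, %v1_red
```

gets transformed to:

```
%1 = sub <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```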
6 files changed: +91 −37 lines

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 5 additions & 13 deletions
```diff
@@ -1528,6 +1528,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -2378,19 +2381,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
-  auto m_AddRdx = [](Value *&Vec) {
-    return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
-  };
-  Value *V0, *V1;
-  if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
-      V0->getType() == V1->getType()) {
-    // Difference of sums is sum of differences:
-    // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
-    Value *Sub = Builder.CreateSub(V0, V1);
-    Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
-                                         {Sub->getType()}, {Sub});
-    return replaceInstUsesWith(I, Rdx);
-  }
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
 
   if (Constant *C = dyn_cast<Constant>(Op0)) {
     Value *X;
```

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 0 deletions
```diff
@@ -2372,6 +2372,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -3552,6 +3555,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
@@ -4663,6 +4669,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
```
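The `or` case can carry a `disjoint` flag; the helper added in InstructionCombining.cpp (below) copies it from the scalar `or` onto the new vector `or`. A sketch of that fold, mirroring the `preserve_disjoint_flags` test, with illustrative value names:

```
%v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v1)
%res = or disjoint i32 %v0_red, %v1_red
```

gets transformed to:

```
%1 = or disjoint <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %1)
```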

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
+  Instruction *foldBinopOfReductions(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);
   Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
 
```

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
```diff
@@ -201,6 +201,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *X = foldVectorBinop(I))
     return X;
 
+  if (Instruction *X = foldBinopOfReductions(I))
+    return replaceInstUsesWith(I, X);
+
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
```

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 57 additions & 0 deletions
```diff
@@ -2312,6 +2312,63 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   return nullptr;
 }
 
+static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
+  switch (Opc) {
+  default:
+    break;
+  case Instruction::Add:
+    return Intrinsic::vector_reduce_add;
+  case Instruction::Mul:
+    return Intrinsic::vector_reduce_mul;
+  case Instruction::And:
+    return Intrinsic::vector_reduce_and;
+  case Instruction::Or:
+    return Intrinsic::vector_reduce_or;
+  case Instruction::Xor:
+    return Intrinsic::vector_reduce_xor;
+  }
+  return Intrinsic::not_intrinsic;
+}
+
+Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
+  Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
+  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+  if (BinOpOpc == Instruction::Sub)
+    ReductionIID = Intrinsic::vector_reduce_add;
+  if (ReductionIID == Intrinsic::not_intrinsic)
+    return nullptr;
+
+  auto checkIntrinsicAndGetItsArgument = [](Value *V,
+                                            Intrinsic::ID IID) -> Value * {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+    if (!II)
+      return nullptr;
+    if (II->getIntrinsicID() == IID && II->hasOneUse())
+      return II->getArgOperand(0);
+    return nullptr;
+  };
+
+  Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
+  if (!V0)
+    return nullptr;
+  Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
+  if (!V1)
+    return nullptr;
+
+  Type *VTy = V0->getType();
+  if (V1->getType() != VTy)
+    return nullptr;
+
+  Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
+
+  if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
+    if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
+      PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
+
+  Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
+  return Rdx;
+}
+
 /// Try to narrow the width of a binop if at least 1 operand is an extend of
 /// of a value. This requires a potentially expensive known bits check to make
 /// sure the narrow op does not overflow.
```
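Note the asymmetry in flag handling: `disjoint` is propagated through `PossiblyDisjointInst`, but `nuw`/`nsw` on a scalar `add` are dropped, since wrap-freedom of the scalar sum of two reductions says nothing about the per-lane vector additions. The `do_not_preserve_overflow_flags` test below pins this down; a sketch with illustrative names:

```
%res = add nuw nsw i32 %v0_red, %v1_red
```

gets transformed to a vector op that intentionally carries no wrap flags:

```
%1 = add <16 x i32> %v0, %v1
%res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1)
```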

llvm/test/Transforms/InstCombine/fold-binop-of-reductions.ll

Lines changed: 16 additions & 24 deletions
```diff
@@ -4,9 +4,8 @@
 define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @add_of_reduce_add(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -31,9 +30,8 @@ define i32 @sub_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @mul_of_reduce_mul(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> %v0)
@@ -45,9 +43,8 @@ define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @and_of_reduce_and(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = and i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> %v0)
@@ -59,9 +56,8 @@ define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @or_of_reduce_or(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = or i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -73,9 +69,8 @@ define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @xor_of_reduce_xor(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @xor_of_reduce_xor(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = xor i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> %v0)
@@ -161,9 +156,8 @@ define i32 @multiple_use_of_reduction_1(<16 x i32> %v0, <16 x i32> %v1, ptr %p)
 define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @do_not_preserve_overflow_flags(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -175,9 +169,8 @@ define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @preserve_disjoint_flags(
 ; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = or disjoint i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or disjoint <16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -189,9 +182,8 @@ define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
 define i32 @add_of_reduce_add_vscale(<vscale x 16 x i32> %v0, <vscale x 16 x i32> %v1) {
 ; CHECK-LABEL: define i32 @add_of_reduce_add_vscale(
 ; CHECK-SAME: <vscale x 16 x i32> [[V0:%.*]], <vscale x 16 x i32> [[V1:%.*]]) {
-; CHECK-NEXT:    [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V0]])
-; CHECK-NEXT:    [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V1]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <vscale x 16 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[RES:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP1]])
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %v0_red = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %v0)
```
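Since `checkIntrinsicAndGetItsArgument` insists on `hasOneUse()` for each reduction call, the fold backs off whenever a reduction result has another consumer, as the `multiple_use_of_reduction_1` test referenced in the hunk context above verifies. A hypothetical input that should be left untouched (the `store` is an assumed extra use, for illustration only):

```
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
%v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1)
store i32 %v1_red, ptr %p        ; second use of %v1_red blocks the fold
%res = add i32 %v0_red, %v1_red
```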
