Skip to content

Commit 9753c8d

Browse files
author
Mikhail Gudim
committed
[InstCombine] Fold binary op of reductions.
Replace a binary op of two reductions with one reduction of the binary op applied to the vectors. For example: ``` %v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0) %v1_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v1) %res = add i32 %v0_red, %v1_red ``` gets transformed to: ``` %1 = add <16 x i32> %v0, %v1 %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %1) ```
1 parent 81a8b20 commit 9753c8d

File tree

6 files changed

+91
-37
lines changed

6 files changed

+91
-37
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,6 +1528,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15281528
if (Instruction *X = foldVectorBinop(I))
15291529
return X;
15301530

1531+
if (Instruction *X = foldBinopOfReductions(I))
1532+
return replaceInstUsesWith(I, X);
1533+
15311534
if (Instruction *Phi = foldBinopWithPhiOperands(I))
15321535
return Phi;
15331536

@@ -2387,19 +2390,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23872390
}
23882391
}
23892392

2390-
auto m_AddRdx = [](Value *&Vec) {
2391-
return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
2392-
};
2393-
Value *V0, *V1;
2394-
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
2395-
V0->getType() == V1->getType()) {
2396-
// Difference of sums is sum of differences:
2397-
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
2398-
Value *Sub = Builder.CreateSub(V0, V1);
2399-
Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
2400-
{Sub->getType()}, {Sub});
2401-
return replaceInstUsesWith(I, Rdx);
2402-
}
2393+
if (Instruction *X = foldBinopOfReductions(I))
2394+
return replaceInstUsesWith(I, X);
24032395

24042396
if (Constant *C = dyn_cast<Constant>(Op0)) {
24052397
Value *X;

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,6 +2385,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
23852385
if (Instruction *X = foldVectorBinop(I))
23862386
return X;
23872387

2388+
if (Instruction *X = foldBinopOfReductions(I))
2389+
return replaceInstUsesWith(I, X);
2390+
23882391
if (Instruction *Phi = foldBinopWithPhiOperands(I))
23892392
return Phi;
23902393

@@ -3565,6 +3568,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
35653568
if (Instruction *X = foldVectorBinop(I))
35663569
return X;
35673570

3571+
if (Instruction *X = foldBinopOfReductions(I))
3572+
return replaceInstUsesWith(I, X);
3573+
35683574
if (Instruction *Phi = foldBinopWithPhiOperands(I))
35693575
return Phi;
35703576

@@ -4688,6 +4694,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
46884694
if (Instruction *X = foldVectorBinop(I))
46894695
return X;
46904696

4697+
if (Instruction *X = foldBinopOfReductions(I))
4698+
return replaceInstUsesWith(I, X);
4699+
46914700
if (Instruction *Phi = foldBinopWithPhiOperands(I))
46924701
return Phi;
46934702

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
594594

595595
/// Canonicalize the position of binops relative to shufflevector.
596596
Instruction *foldVectorBinop(BinaryOperator &Inst);
597+
Instruction *foldBinopOfReductions(BinaryOperator &Inst);
597598
Instruction *foldVectorSelect(SelectInst &Sel);
598599
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
599600

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
199199
if (Instruction *X = foldVectorBinop(I))
200200
return X;
201201

202+
if (Instruction *X = foldBinopOfReductions(I))
203+
return replaceInstUsesWith(I, X);
204+
202205
if (Instruction *Phi = foldBinopWithPhiOperands(I))
203206
return Phi;
204207

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,6 +2318,63 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23182318
return nullptr;
23192319
}
23202320

2321+
static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
2322+
switch (Opc) {
2323+
default:
2324+
break;
2325+
case Instruction::Add:
2326+
return Intrinsic::vector_reduce_add;
2327+
case Instruction::Mul:
2328+
return Intrinsic::vector_reduce_mul;
2329+
case Instruction::And:
2330+
return Intrinsic::vector_reduce_and;
2331+
case Instruction::Or:
2332+
return Intrinsic::vector_reduce_or;
2333+
case Instruction::Xor:
2334+
return Intrinsic::vector_reduce_xor;
2335+
}
2336+
return Intrinsic::not_intrinsic;
2337+
}
2338+
2339+
Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
2340+
Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
2341+
Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
2342+
if (BinOpOpc == Instruction::Sub)
2343+
ReductionIID = Intrinsic::vector_reduce_add;
2344+
if (ReductionIID == Intrinsic::not_intrinsic)
2345+
return nullptr;
2346+
2347+
auto checkIntrinsicAndGetItsArgument = [](Value *V,
2348+
Intrinsic::ID IID) -> Value * {
2349+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
2350+
if (!II)
2351+
return nullptr;
2352+
if (II->getIntrinsicID() == IID && II->hasOneUse())
2353+
return II->getArgOperand(0);
2354+
return nullptr;
2355+
};
2356+
2357+
Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
2358+
if (!V0)
2359+
return nullptr;
2360+
Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
2361+
if (!V1)
2362+
return nullptr;
2363+
2364+
Type *VTy = V0->getType();
2365+
if (V1->getType() != VTy)
2366+
return nullptr;
2367+
2368+
Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
2369+
2370+
if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
2371+
if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
2372+
PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
2373+
2374+
Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
2375+
return Rdx;
2376+
}
2377+
23212378
/// Try to narrow the width of a binop if at least 1 operand is an extend of
23222379
/// of a value. This requires a potentially expensive known bits check to make
23232380
/// sure the narrow op does not overflow.

llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
55
; CHECK-LABEL: define i32 @add_of_reduce_add(
66
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
7-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
8-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
9-
; CHECK-NEXT: [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
7+
; CHECK-NEXT: [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
8+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
109
; CHECK-NEXT: ret i32 [[RES]]
1110
;
1211
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -31,9 +30,8 @@ define i32 @sub_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
3130
define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
3231
; CHECK-LABEL: define i32 @mul_of_reduce_mul(
3332
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
34-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V0]])
35-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[V1]])
36-
; CHECK-NEXT: [[RES:%.*]] = mul i32 [[V0_RED]], [[V1_RED]]
33+
; CHECK-NEXT: [[TMP1:%.*]] = mul <16 x i32> [[V0]], [[V1]]
34+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> [[TMP1]])
3735
; CHECK-NEXT: ret i32 [[RES]]
3836
;
3937
%v0_red = tail call i32 @llvm.vector.reduce.mul.v16i32(<16 x i32> %v0)
@@ -45,9 +43,8 @@ define i32 @mul_of_reduce_mul(<16 x i32> %v0, <16 x i32> %v1) {
4543
define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
4644
; CHECK-LABEL: define i32 @and_of_reduce_and(
4745
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
48-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V0]])
49-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[V1]])
50-
; CHECK-NEXT: [[RES:%.*]] = and i32 [[V0_RED]], [[V1_RED]]
46+
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i32> [[V0]], [[V1]]
47+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> [[TMP1]])
5148
; CHECK-NEXT: ret i32 [[RES]]
5249
;
5350
%v0_red = tail call i32 @llvm.vector.reduce.and.v16i32(<16 x i32> %v0)
@@ -59,9 +56,8 @@ define i32 @and_of_reduce_and(<16 x i32> %v0, <16 x i32> %v1) {
5956
define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
6057
; CHECK-LABEL: define i32 @or_of_reduce_or(
6158
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
62-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
63-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
64-
; CHECK-NEXT: [[RES:%.*]] = or i32 [[V0_RED]], [[V1_RED]]
59+
; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i32> [[V0]], [[V1]]
60+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
6561
; CHECK-NEXT: ret i32 [[RES]]
6662
;
6763
%v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -73,9 +69,8 @@ define i32 @or_of_reduce_or(<16 x i32> %v0, <16 x i32> %v1) {
7369
define i32 @xor_of_reduce_xor(<16 x i32> %v0, <16 x i32> %v1) {
7470
; CHECK-LABEL: define i32 @xor_of_reduce_xor(
7571
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
76-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V0]])
77-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[V1]])
78-
; CHECK-NEXT: [[RES:%.*]] = xor i32 [[V0_RED]], [[V1_RED]]
72+
; CHECK-NEXT: [[TMP1:%.*]] = xor <16 x i32> [[V0]], [[V1]]
73+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> [[TMP1]])
7974
; CHECK-NEXT: ret i32 [[RES]]
8075
;
8176
%v0_red = tail call i32 @llvm.vector.reduce.xor.v16i32(<16 x i32> %v0)
@@ -161,9 +156,8 @@ define i32 @multiple_use_of_reduction_1(<16 x i32> %v0, <16 x i32> %v1, ptr %p)
161156
define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
162157
; CHECK-LABEL: define i32 @do_not_preserve_overflow_flags(
163158
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
164-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V0]])
165-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[V1]])
166-
; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[V0_RED]], [[V1_RED]]
159+
; CHECK-NEXT: [[TMP1:%.*]] = add <16 x i32> [[V0]], [[V1]]
160+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP1]])
167161
; CHECK-NEXT: ret i32 [[RES]]
168162
;
169163
%v0_red = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %v0)
@@ -175,9 +169,8 @@ define i32 @do_not_preserve_overflow_flags(<16 x i32> %v0, <16 x i32> %v1) {
175169
define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
176170
; CHECK-LABEL: define i32 @preserve_disjoint_flags(
177171
; CHECK-SAME: <16 x i32> [[V0:%.*]], <16 x i32> [[V1:%.*]]) {
178-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V0]])
179-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[V1]])
180-
; CHECK-NEXT: [[RES:%.*]] = or disjoint i32 [[V0_RED]], [[V1_RED]]
172+
; CHECK-NEXT: [[TMP1:%.*]] = or disjoint <16 x i32> [[V0]], [[V1]]
173+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> [[TMP1]])
181174
; CHECK-NEXT: ret i32 [[RES]]
182175
;
183176
%v0_red = tail call i32 @llvm.vector.reduce.or.v16i32(<16 x i32> %v0)
@@ -189,9 +182,8 @@ define i32 @preserve_disjoint_flags(<16 x i32> %v0, <16 x i32> %v1) {
189182
define i32 @add_of_reduce_add_vscale(<vscale x 16 x i32> %v0, <vscale x 16 x i32> %v1) {
190183
; CHECK-LABEL: define i32 @add_of_reduce_add_vscale(
191184
; CHECK-SAME: <vscale x 16 x i32> [[V0:%.*]], <vscale x 16 x i32> [[V1:%.*]]) {
192-
; CHECK-NEXT: [[V0_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V0]])
193-
; CHECK-NEXT: [[V1_RED:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[V1]])
194-
; CHECK-NEXT: [[RES:%.*]] = add i32 [[V0_RED]], [[V1_RED]]
185+
; CHECK-NEXT: [[TMP1:%.*]] = add <vscale x 16 x i32> [[V0]], [[V1]]
186+
; CHECK-NEXT: [[RES:%.*]] = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> [[TMP1]])
195187
; CHECK-NEXT: ret i32 [[RES]]
196188
;
197189
%v0_red = tail call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %v0)

0 commit comments

Comments
 (0)