Skip to content

Commit 86ae393

Browse files
author
Mikhail Gudim
committed
moved to vectorcombine
1 parent a569a5f commit 86ae393

File tree

9 files changed

+95
-76
lines changed

9 files changed

+95
-76
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
365365

366366
/// Returns the arithmetic instruction opcode used when expanding a reduction.
367367
unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
368+
/// Returns the reduction intrinsic id corresponding to the binary operation.
369+
Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc);
368370

369371
/// Returns the min/max intrinsic used when expanding a min/max reduction.
370372
Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,9 +1528,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15281528
if (Instruction *X = foldVectorBinop(I))
15291529
return X;
15301530

1531-
if (Instruction *X = foldBinopOfReductions(I))
1532-
return replaceInstUsesWith(I, X);
1533-
15341531
if (Instruction *Phi = foldBinopWithPhiOperands(I))
15351532
return Phi;
15361533

@@ -2390,8 +2387,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23902387
}
23912388
}
23922389

2393-
if (Instruction *X = foldBinopOfReductions(I))
2394-
return replaceInstUsesWith(I, X);
2390+
auto m_AddRdx = [](Value *&Vec) {
2391+
return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
2392+
};
2393+
Value *V0, *V1;
2394+
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
2395+
V0->getType() == V1->getType()) {
2396+
// Difference of sums is sum of differences:
2397+
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
2398+
Value *Sub = Builder.CreateSub(V0, V1);
2399+
Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
2400+
{Sub->getType()}, {Sub});
2401+
return replaceInstUsesWith(I, Rdx);
2402+
}
23952403

23962404
if (Constant *C = dyn_cast<Constant>(Op0)) {
23972405
Value *X;

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,9 +2380,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
23802380
if (Instruction *X = foldVectorBinop(I))
23812381
return X;
23822382

2383-
if (Instruction *X = foldBinopOfReductions(I))
2384-
return replaceInstUsesWith(I, X);
2385-
23862383
if (Instruction *Phi = foldBinopWithPhiOperands(I))
23872384
return Phi;
23882385

@@ -3563,9 +3560,6 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
35633560
if (Instruction *X = foldVectorBinop(I))
35643561
return X;
35653562

3566-
if (Instruction *X = foldBinopOfReductions(I))
3567-
return replaceInstUsesWith(I, X);
3568-
35693563
if (Instruction *Phi = foldBinopWithPhiOperands(I))
35703564
return Phi;
35713565

@@ -4677,9 +4671,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
46774671
if (Instruction *X = foldVectorBinop(I))
46784672
return X;
46794673

4680-
if (Instruction *X = foldBinopOfReductions(I))
4681-
return replaceInstUsesWith(I, X);
4682-
46834674
if (Instruction *Phi = foldBinopWithPhiOperands(I))
46844675
return Phi;
46854676

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
594594

595595
/// Canonicalize the position of binops relative to shufflevector.
596596
Instruction *foldVectorBinop(BinaryOperator &Inst);
597-
Instruction *foldBinopOfReductions(BinaryOperator &Inst);
598597
Instruction *foldVectorSelect(SelectInst &Sel);
599598
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
600599

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
199199
if (Instruction *X = foldVectorBinop(I))
200200
return X;
201201

202-
if (Instruction *X = foldBinopOfReductions(I))
203-
return replaceInstUsesWith(I, X);
204-
205202
if (Instruction *Phi = foldBinopWithPhiOperands(I))
206203
return Phi;
207204

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2317,63 +2317,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23172317
return nullptr;
23182318
}
23192319

2320-
static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
2321-
switch (Opc) {
2322-
default:
2323-
break;
2324-
case Instruction::Add:
2325-
return Intrinsic::vector_reduce_add;
2326-
case Instruction::Mul:
2327-
return Intrinsic::vector_reduce_mul;
2328-
case Instruction::And:
2329-
return Intrinsic::vector_reduce_and;
2330-
case Instruction::Or:
2331-
return Intrinsic::vector_reduce_or;
2332-
case Instruction::Xor:
2333-
return Intrinsic::vector_reduce_xor;
2334-
}
2335-
return Intrinsic::not_intrinsic;
2336-
}
2337-
2338-
Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
2339-
Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
2340-
Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
2341-
if (BinOpOpc == Instruction::Sub)
2342-
ReductionIID = Intrinsic::vector_reduce_add;
2343-
if (ReductionIID == Intrinsic::not_intrinsic)
2344-
return nullptr;
2345-
2346-
auto checkIntrinsicAndGetItsArgument = [](Value *V,
2347-
Intrinsic::ID IID) -> Value * {
2348-
IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
2349-
if (!II)
2350-
return nullptr;
2351-
if (II->getIntrinsicID() == IID && II->hasOneUse())
2352-
return II->getArgOperand(0);
2353-
return nullptr;
2354-
};
2355-
2356-
Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
2357-
if (!V0)
2358-
return nullptr;
2359-
Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
2360-
if (!V1)
2361-
return nullptr;
2362-
2363-
Type *VTy = V0->getType();
2364-
if (V1->getType() != VTy)
2365-
return nullptr;
2366-
2367-
Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
2368-
2369-
if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
2370-
if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
2371-
PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
2372-
2373-
Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
2374-
return Rdx;
2375-
}
2376-
23772320
/// Try to narrow the width of a binop if at least 1 operand is an extend
23782321
/// of a value. This requires a potentially expensive known bits check to make
23792322
/// sure the narrow op does not overflow.

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
957957
}
958958
}
959959

960+
// This is the inverse of getReductionForBinop.
960961
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
961962
switch (RdxID) {
962963
case Intrinsic::vector_reduce_fadd:
@@ -986,6 +987,25 @@ unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
986987
}
987988
}
988989

990+
// This is the inverse to getArithmeticReductionInstruction
991+
Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
992+
switch (Opc) {
993+
default:
994+
break;
995+
case Instruction::Add:
996+
return Intrinsic::vector_reduce_add;
997+
case Instruction::Mul:
998+
return Intrinsic::vector_reduce_mul;
999+
case Instruction::And:
1000+
return Intrinsic::vector_reduce_and;
1001+
case Instruction::Or:
1002+
return Intrinsic::vector_reduce_or;
1003+
case Instruction::Xor:
1004+
return Intrinsic::vector_reduce_xor;
1005+
}
1006+
return Intrinsic::not_intrinsic;
1007+
}
1008+
9891009
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
9901010
switch (RdxID) {
9911011
default:

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ static cl::opt<unsigned> MaxInstrsToScan(
6060
"vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
6161
cl::desc("Max number of instructions to scan for vector combining."));
6262

63+
// Testing-only override: when set, foldBinopOfReductions performs the fold
// even if the cost comparison does not report it as profitable.
static cl::opt<bool> ForceFoldBinopOfReductions(
    "vector-combine-force-fold-binop-of-reductions", cl::init(false),
    cl::Hidden,
    cl::desc("Force folding binary op of two reductions even if it is not "
             "profitable. Should be used for testing only."));
68+
6369
static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
6470

6571
namespace {
@@ -113,6 +119,7 @@ class VectorCombine {
113119
bool scalarizeBinopOrCmp(Instruction &I);
114120
bool scalarizeVPIntrinsic(Instruction &I);
115121
bool foldExtractedCmps(Instruction &I);
122+
bool foldBinopOfReductions(Instruction &I);
116123
bool foldSingleElementStore(Instruction &I);
117124
bool scalarizeLoadExtract(Instruction &I);
118125
bool foldConcatOfBoolMasks(Instruction &I);
@@ -1182,6 +1189,57 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
11821189
return true;
11831190
}
11841191

1192+
/// Try to merge a binary operation on two matching vector reductions into a
/// single reduction of a vector binary operation:
///   binop(rdx(V0), rdx(V1)) --> rdx(binop(V0, V1))
/// A `sub` of two add reductions is handled as well, keeping the add
/// reduction: add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1).
///
/// Both reductions must have a single use and reduce vectors of the same
/// type. The fold is only performed when the new sequence is strictly
/// cheaper, unless ForceFoldBinopOfReductions is set.
///
/// \param I the candidate instruction; must be a BinaryOperator.
/// \returns true if \p I was replaced by the merged reduction.
bool VectorCombine::foldBinopOfReductions(Instruction &I) {
  Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
  Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
  // `sub` has no reduction of its own; it pairs with the add reduction.
  if (BinOpOpc == Instruction::Sub)
    ReductionIID = Intrinsic::vector_reduce_add;
  if (ReductionIID == Intrinsic::not_intrinsic)
    return false;

  // Returns the vector operand of V when V is a single-use call to the
  // intrinsic IID, and nullptr otherwise.
  auto checkIntrinsicAndGetItsArgument = [](Value *V,
                                            Intrinsic::ID IID) -> Value * {
    auto *II = dyn_cast<IntrinsicInst>(V);
    if (!II)
      return nullptr;
    if (II->getIntrinsicID() == IID && II->hasOneUse())
      return II->getArgOperand(0);
    return nullptr;
  };

  Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
  if (!V0)
    return false;
  Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
  if (!V1)
    return false;

  auto *VTy = cast<VectorType>(V0->getType());
  if (V1->getType() != VTy)
    return false;

  // The operation performed by the reductions themselves. It equals BinOpOpc
  // except for `sub`, where the reductions are adds; costing the reduction
  // with BinOpOpc would query a sub reduction that is never emitted.
  const unsigned ReductionOpc =
      BinOpOpc == Instruction::Sub ? Instruction::Add : BinOpOpc;

  // Old sequence: two reductions plus the scalar binop.
  // New sequence: one vector binop plus one reduction.
  InstructionCost ReductionCost = TTI.getArithmeticReductionCost(
      ReductionOpc, VTy, std::nullopt, CostKind);
  InstructionCost OldCost =
      2 * ReductionCost + TTI.getInstructionCost(&I, CostKind);
  InstructionCost NewCost =
      ReductionCost + TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind);
  if (NewCost >= OldCost && !ForceFoldBinopOfReductions)
    return false;

  LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
  // Carry the `disjoint` flag of the scalar op over to the new vector op.
  if (auto *PDInst = dyn_cast<PossiblyDisjointInst>(&I))
    if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
      PDVectorBO->setIsDisjoint(PDInst->isDisjoint());

  Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
  replaceValue(I, *Rdx);
  return true;
}
1242+
11851243
// Check if memory loc modified between two instrs in the same BB
11861244
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
11871245
BasicBlock::iterator End,
@@ -3241,6 +3299,7 @@ bool VectorCombine::run() {
32413299
if (Instruction::isBinaryOp(Opcode)) {
32423300
MadeChange |= foldExtractExtract(I);
32433301
MadeChange |= foldExtractedCmps(I);
3302+
MadeChange |= foldBinopOfReductions(I);
32443303
}
32453304
break;
32463305
}

llvm/test/Transforms/VectorCombine/fold-binop-of-reductions.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2-
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
2+
; RUN: opt < %s -passes=vector-combine -vector-combine-force-fold-binop-of-reductions=true -S | FileCheck %s
33

44
define i32 @add_of_reduce_add(<16 x i32> %v0, <16 x i32> %v1) {
55
; CHECK-LABEL: define i32 @add_of_reduce_add(

0 commit comments

Comments
 (0)