Skip to content

Commit bc1a198

Browse files
author
Mikhail Gudim
committed
Improve cost model
1 parent 9753c8d commit bc1a198

File tree

10 files changed

+246
-76
lines changed

10 files changed

+246
-76
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);
365365

366366
/// Returns the arithmetic instruction opcode used when expanding a reduction.
367367
unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);
368+
/// Returns the reduction intrinsic id corresponding to the binary operation.
369+
Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc);
368370

369371
/// Returns the min/max intrinsic used when expanding a min/max reduction.
370372
Intrinsic::ID getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID);

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,9 +1528,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
15281528
if (Instruction *X = foldVectorBinop(I))
15291529
return X;
15301530

1531-
if (Instruction *X = foldBinopOfReductions(I))
1532-
return replaceInstUsesWith(I, X);
1533-
15341531
if (Instruction *Phi = foldBinopWithPhiOperands(I))
15351532
return Phi;
15361533

@@ -2390,8 +2387,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23902387
}
23912388
}
23922389

2393-
if (Instruction *X = foldBinopOfReductions(I))
2394-
return replaceInstUsesWith(I, X);
2390+
auto m_AddRdx = [](Value *&Vec) {
2391+
return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
2392+
};
2393+
Value *V0, *V1;
2394+
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
2395+
V0->getType() == V1->getType()) {
2396+
// Difference of sums is sum of differences:
2397+
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
2398+
Value *Sub = Builder.CreateSub(V0, V1);
2399+
Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
2400+
{Sub->getType()}, {Sub});
2401+
return replaceInstUsesWith(I, Rdx);
2402+
}
23952403

23962404
if (Constant *C = dyn_cast<Constant>(Op0)) {
23972405
Value *X;

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,9 +2385,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
23852385
if (Instruction *X = foldVectorBinop(I))
23862386
return X;
23872387

2388-
if (Instruction *X = foldBinopOfReductions(I))
2389-
return replaceInstUsesWith(I, X);
2390-
23912388
if (Instruction *Phi = foldBinopWithPhiOperands(I))
23922389
return Phi;
23932390

@@ -3568,9 +3565,6 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
35683565
if (Instruction *X = foldVectorBinop(I))
35693566
return X;
35703567

3571-
if (Instruction *X = foldBinopOfReductions(I))
3572-
return replaceInstUsesWith(I, X);
3573-
35743568
if (Instruction *Phi = foldBinopWithPhiOperands(I))
35753569
return Phi;
35763570

@@ -4694,9 +4688,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
46944688
if (Instruction *X = foldVectorBinop(I))
46954689
return X;
46964690

4697-
if (Instruction *X = foldBinopOfReductions(I))
4698-
return replaceInstUsesWith(I, X);
4699-
47004691
if (Instruction *Phi = foldBinopWithPhiOperands(I))
47014692
return Phi;
47024693

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
594594

595595
/// Canonicalize the position of binops relative to shufflevector.
596596
Instruction *foldVectorBinop(BinaryOperator &Inst);
597-
Instruction *foldBinopOfReductions(BinaryOperator &Inst);
598597
Instruction *foldVectorSelect(SelectInst &Sel);
599598
Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
600599

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
199199
if (Instruction *X = foldVectorBinop(I))
200200
return X;
201201

202-
if (Instruction *X = foldBinopOfReductions(I))
203-
return replaceInstUsesWith(I, X);
204-
205202
if (Instruction *Phi = foldBinopWithPhiOperands(I))
206203
return Phi;
207204

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,63 +2318,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
23182318
return nullptr;
23192319
}
23202320

2321-
static Intrinsic::ID getReductionForBinop(Instruction::BinaryOps Opc) {
2322-
switch (Opc) {
2323-
default:
2324-
break;
2325-
case Instruction::Add:
2326-
return Intrinsic::vector_reduce_add;
2327-
case Instruction::Mul:
2328-
return Intrinsic::vector_reduce_mul;
2329-
case Instruction::And:
2330-
return Intrinsic::vector_reduce_and;
2331-
case Instruction::Or:
2332-
return Intrinsic::vector_reduce_or;
2333-
case Instruction::Xor:
2334-
return Intrinsic::vector_reduce_xor;
2335-
}
2336-
return Intrinsic::not_intrinsic;
2337-
}
2338-
2339-
Instruction *InstCombinerImpl::foldBinopOfReductions(BinaryOperator &Inst) {
2340-
Instruction::BinaryOps BinOpOpc = Inst.getOpcode();
2341-
Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
2342-
if (BinOpOpc == Instruction::Sub)
2343-
ReductionIID = Intrinsic::vector_reduce_add;
2344-
if (ReductionIID == Intrinsic::not_intrinsic)
2345-
return nullptr;
2346-
2347-
auto checkIntrinsicAndGetItsArgument = [](Value *V,
2348-
Intrinsic::ID IID) -> Value * {
2349-
IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
2350-
if (!II)
2351-
return nullptr;
2352-
if (II->getIntrinsicID() == IID && II->hasOneUse())
2353-
return II->getArgOperand(0);
2354-
return nullptr;
2355-
};
2356-
2357-
Value *V0 = checkIntrinsicAndGetItsArgument(Inst.getOperand(0), ReductionIID);
2358-
if (!V0)
2359-
return nullptr;
2360-
Value *V1 = checkIntrinsicAndGetItsArgument(Inst.getOperand(1), ReductionIID);
2361-
if (!V1)
2362-
return nullptr;
2363-
2364-
Type *VTy = V0->getType();
2365-
if (V1->getType() != VTy)
2366-
return nullptr;
2367-
2368-
Value *VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
2369-
2370-
if (PossiblyDisjointInst *PDInst = dyn_cast<PossiblyDisjointInst>(&Inst))
2371-
if (auto *PDVectorBO = dyn_cast<PossiblyDisjointInst>(VectorBO))
2372-
PDVectorBO->setIsDisjoint(PDInst->isDisjoint());
2373-
2374-
Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
2375-
return Rdx;
2376-
}
2377-
23782321
/// Try to narrow the width of a binop if at least 1 operand is an extend of
23792322
/// a value. This requires a potentially expensive known bits check to make
23802323
/// sure the narrow op does not overflow.

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
957957
}
958958
}
959959

960+
// This is the inverse to getReductionForBinop
960961
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
961962
switch (RdxID) {
962963
case Intrinsic::vector_reduce_fadd:
@@ -986,6 +987,25 @@ unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
986987
}
987988
}
988989

990+
// This is the inverse to getArithmeticReductionInstruction
991+
Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
992+
switch (Opc) {
993+
default:
994+
break;
995+
case Instruction::Add:
996+
return Intrinsic::vector_reduce_add;
997+
case Instruction::Mul:
998+
return Intrinsic::vector_reduce_mul;
999+
case Instruction::And:
1000+
return Intrinsic::vector_reduce_and;
1001+
case Instruction::Or:
1002+
return Intrinsic::vector_reduce_or;
1003+
case Instruction::Xor:
1004+
return Intrinsic::vector_reduce_xor;
1005+
}
1006+
return Intrinsic::not_intrinsic;
1007+
}
1008+
9891009
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
9901010
switch (RdxID) {
9911011
default:

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class VectorCombine {
114114
bool scalarizeBinopOrCmp(Instruction &I);
115115
bool scalarizeVPIntrinsic(Instruction &I);
116116
bool foldExtractedCmps(Instruction &I);
117+
bool foldBinopOfReductions(Instruction &I);
117118
bool foldSingleElementStore(Instruction &I);
118119
bool scalarizeLoadExtract(Instruction &I);
119120
bool foldConcatOfBoolMasks(Instruction &I);
@@ -1242,6 +1243,121 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
12421243
return true;
12431244
}
12441245

1246+
static void analyzeCostOfVecReduction(const IntrinsicInst &II,
1247+
TTI::TargetCostKind CostKind,
1248+
const TargetTransformInfo &TTI,
1249+
InstructionCost &CostBeforeReduction,
1250+
InstructionCost &CostAfterReduction) {
1251+
Instruction *Op0, *Op1;
1252+
auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
1253+
auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
1254+
unsigned ReductionOpc =
1255+
getArithmeticReductionInstruction(II.getIntrinsicID());
1256+
if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
1257+
bool IsUnsigned = isa<ZExtInst>(RedOp);
1258+
auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
1259+
1260+
CostBeforeReduction =
1261+
TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
1262+
TTI::CastContextHint::None, CostKind, RedOp);
1263+
CostAfterReduction =
1264+
TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
1265+
ExtType, FastMathFlags(), CostKind);
1266+
return;
1267+
}
1268+
if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1269+
match(RedOp,
1270+
m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
1271+
match(Op0, m_ZExtOrSExt(m_Value())) &&
1272+
Op0->getOpcode() == Op1->getOpcode() &&
1273+
Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
1274+
(Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1275+
// Matched reduce.add(ext(mul(ext(A), ext(B)))
1276+
bool IsUnsigned = isa<ZExtInst>(Op0);
1277+
auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
1278+
VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
1279+
1280+
InstructionCost ExtCost =
1281+
TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
1282+
TTI::CastContextHint::None, CostKind, Op0);
1283+
InstructionCost MulCost =
1284+
TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
1285+
InstructionCost Ext2Cost =
1286+
TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
1287+
TTI::CastContextHint::None, CostKind, RedOp);
1288+
1289+
CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1290+
CostAfterReduction =
1291+
TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
1292+
return;
1293+
}
1294+
CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
1295+
std::nullopt, CostKind);
1296+
return;
1297+
}
1298+
1299+
/// Try to merge a binary operation on two vector reductions into a single
/// reduction of a vector binary operation:
///   binop(reduce(V0), reduce(V1)) --> reduce(binop(V0, V1))
///
/// Both operands of \p I must be single-use calls to the reduction intrinsic
/// that corresponds to the binop (a `sub` is paired with add reductions), and
/// the reduced vectors must have the same type. The rewrite is performed only
/// when the cost model reports a valid new cost that is strictly lower than
/// the old one.
///
/// \returns true if \p I was replaced.
bool VectorCombine::foldBinopOfReductions(Instruction &I) {
  const Instruction::BinaryOps Opcode = cast<BinaryOperator>(&I)->getOpcode();

  // A subtraction is handled through add reductions; every other opcode uses
  // whatever getReductionForBinop reports for it.
  Intrinsic::ID RdxID = Intrinsic::vector_reduce_add;
  if (Opcode != Instruction::Sub)
    RdxID = getReductionForBinop(Opcode);
  if (RdxID == Intrinsic::not_intrinsic)
    return false;

  // Returns V as an intrinsic call if it is a single-use call to RdxID,
  // nullptr otherwise.
  auto MatchSingleUseRdx = [RdxID](Value *V) -> IntrinsicInst * {
    auto *Call = dyn_cast<IntrinsicInst>(V);
    if (Call && Call->getIntrinsicID() == RdxID && Call->hasOneUse())
      return Call;
    return nullptr;
  };

  IntrinsicInst *Rdx0 = MatchSingleUseRdx(I.getOperand(0));
  if (!Rdx0)
    return false;
  IntrinsicInst *Rdx1 = MatchSingleUseRdx(I.getOperand(1));
  if (!Rdx1)
    return false;

  Value *Vec0 = Rdx0->getArgOperand(0);
  Value *Vec1 = Rdx1->getArgOperand(0);
  auto *VecTy = cast<VectorType>(Vec0->getType());
  if (Vec1->getType() != VecTy)
    return false;

  unsigned ReductionOpc =
      getArithmeticReductionInstruction(Rdx0->getIntrinsicID());

  // For each existing reduction, query the cost of the instructions feeding
  // it and the cost of the reduction itself.
  InstructionCost OperandCost0 = 0;
  InstructionCost RdxCost0 = 0;
  analyzeCostOfVecReduction(*Rdx0, CostKind, TTI, OperandCost0, RdxCost0);
  InstructionCost OperandCost1 = 0;
  InstructionCost RdxCost1 = 0;
  analyzeCostOfVecReduction(*Rdx1, CostKind, TTI, OperandCost1, RdxCost1);

  // Old: two reductions plus the scalar binop.
  InstructionCost OldCost =
      RdxCost0 + RdxCost1 + TTI.getInstructionCost(&I, CostKind);
  // New: the feeding instructions, one vector binop and one reduction.
  InstructionCost NewCost =
      OperandCost0 + OperandCost1 +
      TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind) +
      TTI.getArithmeticReductionCost(ReductionOpc, VecTy, std::nullopt,
                                     CostKind);
  if (!NewCost.isValid() || NewCost >= OldCost)
    return false;

  LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  Value *VecBinop = Builder.CreateBinOp(Opcode, Vec0, Vec1);
  // Carry the `disjoint` flag of the scalar instruction over to the vector
  // one when both are possibly-disjoint instructions.
  if (auto *ScalarPD = dyn_cast<PossiblyDisjointInst>(&I))
    if (auto *VectorPD = dyn_cast<PossiblyDisjointInst>(VecBinop))
      VectorPD->setIsDisjoint(ScalarPD->isDisjoint());

  Instruction *MergedRdx = Builder.CreateIntrinsic(RdxID, {VecTy}, {VecBinop});
  replaceValue(I, *MergedRdx);
  return true;
}
1360+
12451361
// Check if memory loc modified between two instrs in the same BB
12461362
static bool isMemModifiedBetween(BasicBlock::iterator Begin,
12471363
BasicBlock::iterator End,
@@ -3382,6 +3498,7 @@ bool VectorCombine::run() {
33823498
if (Instruction::isBinaryOp(Opcode)) {
33833499
MadeChange |= foldExtractExtract(I);
33843500
MadeChange |= foldExtractedCmps(I);
3501+
MadeChange |= foldBinopOfReductions(I);
33853502
}
33863503
break;
33873504
}

0 commit comments

Comments
 (0)