Skip to content

Commit 7523086

Browse files
[SLP]Use getExtendedReduction cost and fix reduction cost calculations
Patch uses getExtendedReduction for reductions of ext-based nodes + adds cost estimation for ctpop-kind reductions into basic implementation and RISCV-V specific vcpop cost estimation. Reviewers: RKSimon, preames Reviewed By: preames Pull Request: #117350
1 parent b870336 commit 7523086

File tree

5 files changed

+125
-66
lines changed

5 files changed

+125
-66
lines changed

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
27652765
Type *ResTy, VectorType *Ty,
27662766
FastMathFlags FMF,
27672767
TTI::TargetCostKind CostKind) {
2768+
if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
2769+
FTy && IsUnsigned && Opcode == Instruction::Add &&
2770+
FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
2771+
// Represent vector_reduce_add(ZExt(<n x i1>)) as
2772+
// ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
2773+
auto *IntTy =
2774+
IntegerType::get(ResTy->getContext(), FTy->getNumElements());
2775+
IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
2776+
return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
2777+
TTI::CastContextHint::None, CostKind) +
2778+
thisT()->getIntrinsicInstrCost(ICA, CostKind);
2779+
}
27682780
// Without any native support, this is equivalent to the cost of
27692781
// vecreduce.opcode(ext(Ty A)).
27702782
VectorType *ExtTy = VectorType::get(ResTy, Ty);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
16201620

16211621
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
16221622

1623+
if (IsUnsigned && Opcode == Instruction::Add &&
1624+
LT.second.isFixedLengthVector() && LT.second.getScalarType() == MVT::i1) {
1625+
// Represent vector_reduce_add(ZExt(<n x i1>)) as
1626+
// ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
1627+
return LT.first *
1628+
getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
1629+
}
1630+
16231631
if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
16241632
return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
16251633
FMF, CostKind);

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 100 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,22 +1395,46 @@ class BoUpSLP {
13951395
return VectorizableTree.front()->Scalars;
13961396
}
13971397

1398+
/// Returns the type/is-signed info for the root node in the graph without
1399+
/// casting.
1400+
std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
1401+
const TreeEntry &Root = *VectorizableTree.front().get();
1402+
if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
1403+
!Root.Scalars.front()->getType()->isIntegerTy())
1404+
return std::nullopt;
1405+
auto It = MinBWs.find(&Root);
1406+
if (It != MinBWs.end())
1407+
return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
1408+
It->second.first),
1409+
It->second.second);
1410+
if (Root.getOpcode() == Instruction::ZExt ||
1411+
Root.getOpcode() == Instruction::SExt)
1412+
return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
1413+
Root.getOpcode() == Instruction::SExt);
1414+
return std::nullopt;
1415+
}
1416+
13981417
/// Checks if the root graph node can be emitted with narrower bitwidth at
13991418
/// codegen and returns it signedness, if so.
14001419
bool isSignedMinBitwidthRootNode() const {
14011420
return MinBWs.at(VectorizableTree.front().get()).second;
14021421
}
14031422

1404-
/// Returns reduction bitwidth and signedness, if it does not match the
1405-
/// original requested size.
1406-
std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
1423+
/// Returns the reduction type after min-bitwidth analysis.
1424+
FixedVectorType *getReductionType() const {
14071425
if (ReductionBitWidth == 0 ||
1426+
!VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
14081427
ReductionBitWidth >=
14091428
DL->getTypeSizeInBits(
14101429
VectorizableTree.front()->Scalars.front()->getType()))
1411-
return std::nullopt;
1412-
return std::make_pair(ReductionBitWidth,
1413-
MinBWs.at(VectorizableTree.front().get()).second);
1430+
return getWidenedType(
1431+
VectorizableTree.front()->Scalars.front()->getType(),
1432+
VectorizableTree.front()->getVectorFactor());
1433+
return getWidenedType(
1434+
IntegerType::get(
1435+
VectorizableTree.front()->Scalars.front()->getContext(),
1436+
ReductionBitWidth),
1437+
VectorizableTree.front()->getVectorFactor());
14141438
}
14151439

14161440
/// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11384,6 +11408,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1138411408
return CommonCost;
1138511409
auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
1138611410
TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
11411+
11412+
bool IsArithmeticExtendedReduction =
11413+
E->Idx == 0 && UserIgnoreList &&
11414+
all_of(*UserIgnoreList, [](Value *V) {
11415+
auto *I = cast<Instruction>(V);
11416+
return is_contained({Instruction::Add, Instruction::FAdd,
11417+
Instruction::Mul, Instruction::FMul,
11418+
Instruction::And, Instruction::Or,
11419+
Instruction::Xor},
11420+
I->getOpcode());
11421+
});
11422+
if (IsArithmeticExtendedReduction &&
11423+
(VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
11424+
return CommonCost;
1138711425
return CommonCost +
1138811426
TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
1138911427
VecOpcode == Opcode ? VI : nullptr);
@@ -12748,32 +12786,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
1274812786
unsigned SrcSize = It->second.first;
1274912787
unsigned DstSize = ReductionBitWidth;
1275012788
unsigned Opcode = Instruction::Trunc;
12751-
if (SrcSize < DstSize)
12752-
Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
12753-
auto *SrcVecTy =
12754-
getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
12755-
auto *DstVecTy =
12756-
getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
12757-
TTI::CastContextHint CCH = getCastContextHint(E);
12758-
InstructionCost CastCost;
12759-
switch (E.getOpcode()) {
12760-
case Instruction::SExt:
12761-
case Instruction::ZExt:
12762-
case Instruction::Trunc: {
12763-
const TreeEntry *OpTE = getOperandEntry(&E, 0);
12764-
CCH = getCastContextHint(*OpTE);
12765-
break;
12766-
}
12767-
default:
12768-
break;
12789+
if (SrcSize < DstSize) {
12790+
bool IsArithmeticExtendedReduction =
12791+
all_of(*UserIgnoreList, [](Value *V) {
12792+
auto *I = cast<Instruction>(V);
12793+
return is_contained({Instruction::Add, Instruction::FAdd,
12794+
Instruction::Mul, Instruction::FMul,
12795+
Instruction::And, Instruction::Or,
12796+
Instruction::Xor},
12797+
I->getOpcode());
12798+
});
12799+
if (IsArithmeticExtendedReduction)
12800+
Opcode =
12801+
Instruction::BitCast; // Handled by getExtendedReductionCost.
12802+
else
12803+
Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
12804+
}
12805+
if (Opcode != Instruction::BitCast) {
12806+
auto *SrcVecTy =
12807+
getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
12808+
auto *DstVecTy =
12809+
getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
12810+
TTI::CastContextHint CCH = getCastContextHint(E);
12811+
InstructionCost CastCost;
12812+
switch (E.getOpcode()) {
12813+
case Instruction::SExt:
12814+
case Instruction::ZExt:
12815+
case Instruction::Trunc: {
12816+
const TreeEntry *OpTE = getOperandEntry(&E, 0);
12817+
CCH = getCastContextHint(*OpTE);
12818+
break;
12819+
}
12820+
default:
12821+
break;
12822+
}
12823+
CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
12824+
TTI::TCK_RecipThroughput);
12825+
Cost += CastCost;
12826+
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
12827+
<< " for final resize for reduction from " << SrcVecTy
12828+
<< " to " << DstVecTy << "\n";
12829+
dbgs() << "SLP: Current total cost = " << Cost << "\n");
1276912830
}
12770-
CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
12771-
TTI::TCK_RecipThroughput);
12772-
Cost += CastCost;
12773-
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
12774-
<< " for final resize for reduction from " << SrcVecTy
12775-
<< " to " << DstVecTy << "\n";
12776-
dbgs() << "SLP: Current total cost = " << Cost << "\n");
1277712831
}
1277812832
}
1277912833

@@ -19951,8 +20005,8 @@ class HorizontalReduction {
1995120005

1995220006
// Estimate cost.
1995320007
InstructionCost TreeCost = V.getTreeCost(VL);
19954-
InstructionCost ReductionCost = getReductionCost(
19955-
TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
20008+
InstructionCost ReductionCost =
20009+
getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
1995620010
InstructionCost Cost = TreeCost + ReductionCost;
1995720011
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
1995820012
<< " for reduction\n");
@@ -20243,14 +20297,14 @@ class HorizontalReduction {
2024320297

2024420298
private:
2024520299
/// Calculate the cost of a reduction.
20246-
InstructionCost getReductionCost(
20247-
TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
20248-
bool IsCmpSelMinMax, FastMathFlags FMF,
20249-
const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
20300+
InstructionCost getReductionCost(TargetTransformInfo *TTI,
20301+
ArrayRef<Value *> ReducedVals,
20302+
bool IsCmpSelMinMax, FastMathFlags FMF,
20303+
const BoUpSLP &R) {
2025020304
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2025120305
Type *ScalarTy = ReducedVals.front()->getType();
2025220306
unsigned ReduxWidth = ReducedVals.size();
20253-
FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
20307+
FixedVectorType *VectorTy = R.getReductionType();
2025420308
InstructionCost VectorCost = 0, ScalarCost;
2025520309
// If all of the reduced values are constant, the vector cost is 0, since
2025620310
// the reduction value can be calculated at the compile time.
@@ -20308,21 +20362,16 @@ class HorizontalReduction {
2030820362
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
2030920363
/*Extract*/ false, TTI::TCK_RecipThroughput);
2031020364
} else {
20311-
auto [Bitwidth, IsSigned] =
20312-
BitwidthAndSign.value_or(std::make_pair(0u, false));
20313-
if (RdxKind == RecurKind::Add && Bitwidth == 1) {
20314-
// Represent vector_reduce_add(ZExt(<n x i1>)) to
20315-
// ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
20316-
auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
20317-
IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
20318-
VectorCost =
20319-
TTI->getCastInstrCost(Instruction::BitCast, IntTy,
20320-
getWidenedType(ScalarTy, ReduxWidth),
20321-
TTI::CastContextHint::None, CostKind) +
20322-
TTI->getIntrinsicInstrCost(ICA, CostKind);
20323-
} else {
20365+
Type *RedTy = VectorTy->getElementType();
20366+
auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
20367+
std::make_pair(RedTy, true));
20368+
if (RType == RedTy) {
2032420369
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
2032520370
FMF, CostKind);
20371+
} else {
20372+
VectorCost = TTI->getExtendedReductionCost(
20373+
RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
20374+
FMF, CostKind);
2032620375
}
2032720376
}
2032820377
}

llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -877,20 +877,10 @@ entry:
877877
define i64 @red_zext_ld_4xi64(ptr %ptr) {
878878
; CHECK-LABEL: @red_zext_ld_4xi64(
879879
; CHECK-NEXT: entry:
880-
; CHECK-NEXT: [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
881-
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
882-
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
883-
; CHECK-NEXT: [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
884-
; CHECK-NEXT: [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
885-
; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
886-
; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
887-
; CHECK-NEXT: [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
888-
; CHECK-NEXT: [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
889-
; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
890-
; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
891-
; CHECK-NEXT: [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
892-
; CHECK-NEXT: [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
893-
; CHECK-NEXT: [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
880+
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
881+
; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
882+
; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
883+
; CHECK-NEXT: [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
894884
; CHECK-NEXT: ret i64 [[ADD_3]]
895885
;
896886
entry:

llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
; YAML-NEXT: Function: test
99
; YAML-NEXT: Args:
1010
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
11-
; YAML-NEXT: - Cost: '-1'
11+
; YAML-NEXT: - Cost: '-10'
1212
; YAML-NEXT: - String: ' and with tree size '
1313
; YAML-NEXT: - TreeSize: '8'
1414
; YAML-NEXT:...

0 commit comments

Comments
 (0)