[SLP]Use getExtendedReduction cost and fix reduction cost calculations #117350
Conversation
Created using spr 1.3.5
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Patch uses getExtendedReduction for reductions of ext-based nodes + adds cost estimation for ctpop-kind reductions into the basic implementation and RISC-V-specific vcpop cost estimation.

Full diff: https://github.com/llvm/llvm-project/pull/117350.diff

5 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4c..d4bd504bf9ba31 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Type *ResTy, VectorType *Ty,
FastMathFlags FMF,
TTI::TargetCostKind CostKind) {
+ if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
+ FTy && Opcode == Instruction::Add &&
+ FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
+ // Represent vector_reduce_add(ZExt(<n x i1>)) as
+ // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+ auto *IntTy =
+ IntegerType::get(ResTy->getContext(), FTy->getNumElements());
+ IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+ return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
+ TTI::CastContextHint::None, CostKind) +
+ thisT()->getIntrinsicInstrCost(ICA, CostKind);
+ }
// Without any native support, this is equivalent to the cost of
// vecreduce.opcode(ext(Ty A)).
VectorType *ExtTy = VectorType::get(ResTy, Ty);
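For reference, the equivalence this new branch prices can be written directly in IR. This is a hand-written sketch, not code from the patch; the function names are made up.

; An add-reduction of a zero-extended i1 mask adds 1 per set lane...
define i32 @reduce_of_zext(<8 x i1> %m) {
  %ext = zext <8 x i1> %m to <8 x i32>
  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %ext)
  ret i32 %sum
}

; ...so it can be priced as a mask bitcast plus a scalar ctpop.
define i32 @ctpop_form(<8 x i1> %m) {
  %bits = bitcast <8 x i1> %m to i8
  %cnt = call i8 @llvm.ctpop.i8(i8 %bits)
  %res = zext i8 %cnt to i32
  ret i32 %res
}

declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i8 @llvm.ctpop.i8(i8)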
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2b16dcbcd8695b..026b1e694e6a64 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
+ if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
+ LT.second.getScalarType() == MVT::i1) {
+ // Represent vector_reduce_add(ZExt(<n x i1>)) as
+ // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+ return LT.first *
+ getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
+ }
+
if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
FMF, CostKind);
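Background note (mine, not from the patch): vcpop.m is the RVV instruction that counts the set bits of a mask register, so a fixed-length i1 add-reduction needs one vcpop.m per legalized part; the LT.first factor above is exactly that part count.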
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8e0ca2677bf0a9..46ae908f57ab89 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,22 +1371,46 @@ class BoUpSLP {
return VectorizableTree.front()->Scalars;
}
+ /// Returns the type/is-signed info for the root node in the graph without
+ /// casting.
+ std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+ const TreeEntry &Root = *VectorizableTree.front().get();
+ if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+ !Root.Scalars.front()->getType()->isIntegerTy())
+ return std::nullopt;
+ auto It = MinBWs.find(&Root);
+ if (It != MinBWs.end())
+ return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+ It->second.first),
+ It->second.second);
+ if (Root.getOpcode() == Instruction::ZExt ||
+ Root.getOpcode() == Instruction::SExt)
+ return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+ Root.getOpcode() == Instruction::SExt);
+ return std::nullopt;
+ }
+
/// Checks if the root graph node can be emitted with narrower bitwidth at
/// codegen and returns its signedness, if so.
bool isSignedMinBitwidthRootNode() const {
return MinBWs.at(VectorizableTree.front().get()).second;
}
- /// Returns reduction bitwidth and signedness, if it does not match the
- /// original requested size.
- std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+ /// Returns reduction type after minbitwidth analysis.
+ FixedVectorType *getReductionType() const {
if (ReductionBitWidth == 0 ||
+ !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
ReductionBitWidth >=
DL->getTypeSizeInBits(
VectorizableTree.front()->Scalars.front()->getType()))
- return std::nullopt;
- return std::make_pair(ReductionBitWidth,
- MinBWs.at(VectorizableTree.front().get()).second);
+ return getWidenedType(
+ VectorizableTree.front()->Scalars.front()->getType(),
+ VectorizableTree.front()->getVectorFactor());
+ return getWidenedType(
+ IntegerType::get(
+ VectorizableTree.front()->Scalars.front()->getContext(),
+ ReductionBitWidth),
+ VectorizableTree.front()->getVectorFactor());
}
/// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return CommonCost;
auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+ bool IsArithmeticExtendedReduction =
+ E->Idx == 0 && UserIgnoreList &&
+ all_of(*UserIgnoreList, [](Value *V) {
+ auto *I = cast<Instruction>(V);
+ return is_contained({Instruction::Add, Instruction::FAdd,
+ Instruction::Mul, Instruction::FMul,
+ Instruction::And, Instruction::Or,
+ Instruction::Xor},
+ I->getOpcode());
+ });
+ if (IsArithmeticExtendedReduction &&
+ (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+ return CommonCost;
return CommonCost +
TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
VecOpcode == Opcode ? VI : nullptr);
@@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
unsigned SrcSize = It->second.first;
unsigned DstSize = ReductionBitWidth;
unsigned Opcode = Instruction::Trunc;
- if (SrcSize < DstSize)
- Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
- auto *SrcVecTy =
- getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
- auto *DstVecTy =
- getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
- TTI::CastContextHint CCH = getCastContextHint(E);
- InstructionCost CastCost;
- switch (E.getOpcode()) {
- case Instruction::SExt:
- case Instruction::ZExt:
- case Instruction::Trunc: {
- const TreeEntry *OpTE = getOperandEntry(&E, 0);
- CCH = getCastContextHint(*OpTE);
- break;
- }
- default:
- break;
+ if (SrcSize < DstSize) {
+ bool IsArithmeticExtendedReduction =
+ all_of(*UserIgnoreList, [](Value *V) {
+ auto *I = cast<Instruction>(V);
+ return is_contained({Instruction::Add, Instruction::FAdd,
+ Instruction::Mul, Instruction::FMul,
+ Instruction::And, Instruction::Or,
+ Instruction::Xor},
+ I->getOpcode());
+ });
+ if (IsArithmeticExtendedReduction)
+ Opcode =
+ Instruction::BitCast; // Handle it by getExtendedReductionCost
+ else
+ Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+ }
+ if (Opcode != Instruction::BitCast) {
+ auto *SrcVecTy =
+ getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+ auto *DstVecTy =
+ getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+ TTI::CastContextHint CCH = getCastContextHint(E);
+ InstructionCost CastCost;
+ switch (E.getOpcode()) {
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ case Instruction::Trunc: {
+ const TreeEntry *OpTE = getOperandEntry(&E, 0);
+ CCH = getCastContextHint(*OpTE);
+ break;
+ }
+ default:
+ break;
+ }
+ CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+ TTI::TCK_RecipThroughput);
+ Cost += CastCost;
+ LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+ << " for final resize for reduction from " << SrcVecTy
+ << " to " << DstVecTy << "\n";
+ dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
- CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
- TTI::TCK_RecipThroughput);
- Cost += CastCost;
- LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
- << " for final resize for reduction from " << SrcVecTy
- << " to " << DstVecTy << "\n";
- dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
}
@@ -19815,8 +19869,8 @@ class HorizontalReduction {
// Estimate cost.
InstructionCost TreeCost = V.getTreeCost(VL);
- InstructionCost ReductionCost = getReductionCost(
- TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+ InstructionCost ReductionCost =
+ getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
<< " for reduction\n");
@@ -20107,14 +20161,14 @@ class HorizontalReduction {
private:
/// Calculate the cost of a reduction.
- InstructionCost getReductionCost(
- TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
- bool IsCmpSelMinMax, FastMathFlags FMF,
- const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+ InstructionCost getReductionCost(TargetTransformInfo *TTI,
+ ArrayRef<Value *> ReducedVals,
+ bool IsCmpSelMinMax, FastMathFlags FMF,
+ const BoUpSLP &R) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *ScalarTy = ReducedVals.front()->getType();
unsigned ReduxWidth = ReducedVals.size();
- FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+ FixedVectorType *VectorTy = R.getReductionType();
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
// the reduction value can be calculated at the compile time.
@@ -20172,21 +20226,16 @@ class HorizontalReduction {
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
/*Extract*/ false, TTI::TCK_RecipThroughput);
} else {
- auto [Bitwidth, IsSigned] =
- BitwidthAndSign.value_or(std::make_pair(0u, false));
- if (RdxKind == RecurKind::Add && Bitwidth == 1) {
- // Represent vector_reduce_add(ZExt(<n x i1>)) to
- // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
- auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
- IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
- VectorCost =
- TTI->getCastInstrCost(Instruction::BitCast, IntTy,
- getWidenedType(ScalarTy, ReduxWidth),
- TTI::CastContextHint::None, CostKind) +
- TTI->getIntrinsicInstrCost(ICA, CostKind);
- } else {
+ Type *RedTy = VectorTy->getElementType();
+ auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+ std::make_pair(RedTy, true));
+ if (RType == RedTy) {
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
FMF, CostKind);
+ } else {
+ VectorCost = TTI->getExtendedReductionCost(
+ RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+ FMF, CostKind);
}
}
}
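Reader's note on how the two SLPVectorizer hunks fit together (my reading of the patch, not the author's wording): the getEntryCost change stops charging separately for a root ZExt/SExt that feeds an arithmetic reduction, and the getTreeCost/getReductionCost changes instead price the combined ext+reduce via getExtendedReductionCost, so the extension cost is neither double-counted nor dropped.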
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index bc24a44cecbe39..85131758853b3d 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -877,20 +877,10 @@ entry:
define i64 @red_zext_ld_4xi64(ptr %ptr) {
; CHECK-LABEL: @red_zext_ld_4xi64(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
-; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
-; CHECK-NEXT: [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT: [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
-; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
-; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
-; CHECK-NEXT: [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
-; CHECK-NEXT: [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
-; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
-; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
-; CHECK-NEXT: [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
-; CHECK-NEXT: [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
-; CHECK-NEXT: [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
+; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
+; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
+; CHECK-NEXT: [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
; CHECK-NEXT: ret i64 [[ADD_3]]
;
entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
index e4d20a6db8fa67..09c11bbefd4a35 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
@@ -8,7 +8,7 @@
; YAML-NEXT: Function: test
; YAML-NEXT: Args:
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
-; YAML-NEXT: - Cost: '-1'
+; YAML-NEXT: - Cost: '-10'
; YAML-NEXT: - String: ' and with tree size '
; YAML-NEXT: - TreeSize: '8'
; YAML-NEXT:...
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Type *ResTy, VectorType *Ty,
FastMathFlags FMF,
TTI::TargetCostKind CostKind) {
if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
FTy && Opcode == Instruction::Add &&
You're missing the IsUnsigned check here - this allows sext too.
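To illustrate the reviewer's point, here is my own IR sketch (not from the thread): with sext, each set lane contributes -1, so the reduction is the negated popcount and needs one negate beyond what the bitcast+ctpop costing accounts for.

; == sub(0, ctpop(bitcast %m to i8)), not ctpop itself.
define i32 @reduce_of_sext(<8 x i1> %m) {
  %ext = sext <8 x i1> %m to <8 x i32>
  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %ext)
  ret i32 %sum
}

declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)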
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
if (Opcode == Instruction::Add && LT.second.isFixedLengthVector() &&
LT.second.getScalarType() == MVT::i1) {
Again, missing the unsigned check.
Created using spr 1.3.5
LGTM - Note that I'm only confident in the TTI bits. The SLP bits aren't obviously broken, but that's all I can say.
…SExt

PR llvm#117350 made changes to the SLP vectorizer which introduced a regression on ARM vectorization benchmarks. This was due to the changes assuming that SExt/ZExt vector instructions have constant cost. This behaviour is expected for RISCV but not on ARM, where we take the source and destination types of SExt/ZExt instructions into account when calculating vector cost.

Change-Id: I6f995dcde26e5aaf62b779b63e52988fb333f941
PR #117350 made changes to the SLP vectorizer which introduced a regression on some ARM benchmarks. Investigation narrowed it down to suboptimal codegen for benchmarks that previously used only scalar (U/S)MLAL instructions. The linked change meant the SLPVectorizer thought these could be vectorized. This change makes the cost of muls in (U/S)MLAL patterns slightly cheaper, to make sure scalar instructions are preferred in these cases over SLP vectorization on targets supporting DSP.
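For background on that follow-up, the scalar shape that ARM's SMLAL/UMLAL instructions cover is a widening multiply-accumulate, roughly as below (an illustrative sketch, not code from any of the patches):

; One SMLAL step: sign-extend two 32-bit operands, multiply into 64 bits,
; and accumulate; DSP-capable ARM targets do this in a single instruction.
define i64 @smlal_like(i64 %acc, i32 %a, i32 %b) {
  %wa = sext i32 %a to i64
  %wb = sext i32 %b to i64
  %mul = mul nsw i64 %wa, %wb
  %sum = add i64 %acc, %mul
  ret i64 %sum
}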