Skip to content

Commit 0ad6be1

Browse files
authored
[SLPVectorizer, TargetTransformInfo, SystemZ] Improve SLP getGatherCost(). (#112491)
As vector element loads are free on SystemZ, this patch improves the cost computation in getGatherCost() to reflect this. getScalarizationOverhead() gets an optional parameter which can hold the actual Values so that they in turn can be passed (by BasicTTIImpl) to getVectorInstrCost(). SystemZTTIImpl::getVectorInstrCost() will now recognize a LoadInst and typically return a 0 cost for it, with some exceptions.
1 parent cbf495f commit 0ad6be1

14 files changed

+209
-88
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -909,11 +909,13 @@ class TargetTransformInfo {
909909

910910
/// Estimate the overhead of scalarizing an instruction. Insert and Extract
911911
/// are set if the demanded result elements need to be inserted and/or
912-
/// extracted from vectors.
912+
/// extracted from vectors. The involved values may be passed in VL if
913+
/// Insert is true.
913914
InstructionCost getScalarizationOverhead(VectorType *Ty,
914915
const APInt &DemandedElts,
915916
bool Insert, bool Extract,
916-
TTI::TargetCostKind CostKind) const;
917+
TTI::TargetCostKind CostKind,
918+
ArrayRef<Value *> VL = {}) const;
917919

918920
/// Estimate the overhead of scalarizing an instructions unique
919921
/// non-constant operands. The (potentially vector) types to use for each of
@@ -2001,10 +2003,10 @@ class TargetTransformInfo::Concept {
20012003
unsigned ScalarOpdIdx) = 0;
20022004
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
20032005
int ScalarOpdIdx) = 0;
2004-
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
2005-
const APInt &DemandedElts,
2006-
bool Insert, bool Extract,
2007-
TargetCostKind CostKind) = 0;
2006+
virtual InstructionCost
2007+
getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
2008+
bool Insert, bool Extract, TargetCostKind CostKind,
2009+
ArrayRef<Value *> VL = {}) = 0;
20082010
virtual InstructionCost
20092011
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
20102012
ArrayRef<Type *> Tys,
@@ -2585,9 +2587,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25852587
InstructionCost getScalarizationOverhead(VectorType *Ty,
25862588
const APInt &DemandedElts,
25872589
bool Insert, bool Extract,
2588-
TargetCostKind CostKind) override {
2590+
TargetCostKind CostKind,
2591+
ArrayRef<Value *> VL = {}) override {
25892592
return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
2590-
CostKind);
2593+
CostKind, VL);
25912594
}
25922595
InstructionCost
25932596
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ class TargetTransformInfoImplBase {
404404
InstructionCost getScalarizationOverhead(VectorType *Ty,
405405
const APInt &DemandedElts,
406406
bool Insert, bool Extract,
407-
TTI::TargetCostKind CostKind) const {
407+
TTI::TargetCostKind CostKind,
408+
ArrayRef<Value *> VL = {}) const {
408409
return 0;
409410
}
410411

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -780,24 +780,28 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
780780
InstructionCost getScalarizationOverhead(VectorType *InTy,
781781
const APInt &DemandedElts,
782782
bool Insert, bool Extract,
783-
TTI::TargetCostKind CostKind) {
783+
TTI::TargetCostKind CostKind,
784+
ArrayRef<Value *> VL = {}) {
784785
/// FIXME: a bitfield is not a reasonable abstraction for talking about
785786
/// which elements are needed from a scalable vector
786787
if (isa<ScalableVectorType>(InTy))
787788
return InstructionCost::getInvalid();
788789
auto *Ty = cast<FixedVectorType>(InTy);
789790

790791
assert(DemandedElts.getBitWidth() == Ty->getNumElements() &&
792+
(VL.empty() || VL.size() == Ty->getNumElements()) &&
791793
"Vector size mismatch");
792794

793795
InstructionCost Cost = 0;
794796

795797
for (int i = 0, e = Ty->getNumElements(); i < e; ++i) {
796798
if (!DemandedElts[i])
797799
continue;
798-
if (Insert)
800+
if (Insert) {
801+
Value *InsertedVal = VL.empty() ? nullptr : VL[i];
799802
Cost += thisT()->getVectorInstrCost(Instruction::InsertElement, Ty,
800-
CostKind, i, nullptr, nullptr);
803+
CostKind, i, nullptr, InsertedVal);
804+
}
801805
if (Extract)
802806
Cost += thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty,
803807
CostKind, i, nullptr, nullptr);

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,9 @@ bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
622622

623623
InstructionCost TargetTransformInfo::getScalarizationOverhead(
624624
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
625-
TTI::TargetCostKind CostKind) const {
625+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) const {
626626
return TTIImpl->getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
627-
CostKind);
627+
CostKind, VL);
628628
}
629629

630630
InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3363,7 +3363,7 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
33633363

33643364
InstructionCost AArch64TTIImpl::getScalarizationOverhead(
33653365
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
3366-
TTI::TargetCostKind CostKind) {
3366+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
33673367
if (isa<ScalableVectorType>(Ty))
33683368
return InstructionCost::getInvalid();
33693369
if (Ty->getElementType()->isFloatingPointTy())

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
423423
InstructionCost getScalarizationOverhead(VectorType *Ty,
424424
const APInt &DemandedElts,
425425
bool Insert, bool Extract,
426-
TTI::TargetCostKind CostKind);
426+
TTI::TargetCostKind CostKind,
427+
ArrayRef<Value *> VL = {});
427428

428429
/// Return the cost of the scaling factor used in the addressing
429430
/// mode represented by AM for this target, for a load/store

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ static unsigned isM1OrSmaller(MVT VT) {
669669

670670
InstructionCost RISCVTTIImpl::getScalarizationOverhead(
671671
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
672-
TTI::TargetCostKind CostKind) {
672+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
673673
if (isa<ScalableVectorType>(Ty))
674674
return InstructionCost::getInvalid();
675675

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
149149
InstructionCost getScalarizationOverhead(VectorType *Ty,
150150
const APInt &DemandedElts,
151151
bool Insert, bool Extract,
152-
TTI::TargetCostKind CostKind);
152+
TTI::TargetCostKind CostKind,
153+
ArrayRef<Value *> VL = {});
153154

154155
InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
155156
TTI::TargetCostKind CostKind);

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,42 @@ bool SystemZTTIImpl::hasDivRemOp(Type *DataType, bool IsSigned) {
468468
return (VT.isScalarInteger() && TLI->isTypeLegal(VT));
469469
}
470470

471+
static bool isFreeEltLoad(Value *Op) {
472+
if (isa<LoadInst>(Op) && Op->hasOneUse()) {
473+
const Instruction *UserI = cast<Instruction>(*Op->user_begin());
474+
return !isa<StoreInst>(UserI); // Prefer MVC
475+
}
476+
return false;
477+
}
478+
479+
InstructionCost SystemZTTIImpl::getScalarizationOverhead(
480+
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
481+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
482+
unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
483+
InstructionCost Cost = 0;
484+
485+
if (Insert && Ty->isIntOrIntVectorTy(64)) {
486+
// VLVGP will insert two GPRs with one instruction, while VLE will load
487+
// an element directly with no extra cost
488+
assert((VL.empty() || VL.size() == NumElts) &&
489+
"Type does not match the number of values.");
490+
InstructionCost CurrVectorCost = 0;
491+
for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
492+
if (DemandedElts[Idx] && !(VL.size() && isFreeEltLoad(VL[Idx])))
493+
++CurrVectorCost;
494+
if (Idx % 2 == 1) {
495+
Cost += std::min(InstructionCost(1), CurrVectorCost);
496+
CurrVectorCost = 0;
497+
}
498+
}
499+
Insert = false;
500+
}
501+
502+
Cost += BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
503+
CostKind, VL);
504+
return Cost;
505+
}
506+
471507
// Return the bit size for the scalar type or vector element
472508
// type. getScalarSizeInBits() returns 0 for a pointer type.
473509
static unsigned getScalarSizeInBits(Type *Ty) {
@@ -609,7 +645,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
609645
if (DivRemConst) {
610646
SmallVector<Type *> Tys(Args.size(), Ty);
611647
return VF * DivMulSeqCost +
612-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
648+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
613649
}
614650
if ((SignedDivRem || UnsignedDivRem) && VF > 4)
615651
// Temporary hack: disable high vectorization factors with integer
@@ -636,7 +672,7 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
636672
SmallVector<Type *> Tys(Args.size(), Ty);
637673
InstructionCost Cost =
638674
(VF * ScalarCost) +
639-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
675+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
640676
// FIXME: VF 2 for these FP operations are currently just as
641677
// expensive as for VF 4.
642678
if (VF == 2)
@@ -654,8 +690,9 @@ InstructionCost SystemZTTIImpl::getArithmeticInstrCost(
654690
// There is no native support for FRem.
655691
if (Opcode == Instruction::FRem) {
656692
SmallVector<Type *> Tys(Args.size(), Ty);
657-
InstructionCost Cost = (VF * LIBCALL_COST) +
658-
getScalarizationOverhead(VTy, Args, Tys, CostKind);
693+
InstructionCost Cost =
694+
(VF * LIBCALL_COST) +
695+
BaseT::getScalarizationOverhead(VTy, Args, Tys, CostKind);
659696
// FIXME: VF 2 for float is currently just as expensive as for VF 4.
660697
if (VF == 2 && ScalarBits == 32)
661698
Cost *= 2;
@@ -975,10 +1012,10 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
9751012
(Opcode == Instruction::FPToSI || Opcode == Instruction::FPToUI))
9761013
NeedsExtracts = false;
9771014

978-
TotCost += getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
979-
NeedsExtracts, CostKind);
980-
TotCost += getScalarizationOverhead(DstVecTy, NeedsInserts,
981-
/*Extract*/ false, CostKind);
1015+
TotCost += BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1016+
NeedsExtracts, CostKind);
1017+
TotCost += BaseT::getScalarizationOverhead(DstVecTy, NeedsInserts,
1018+
/*Extract*/ false, CostKind);
9821019

9831020
// FIXME: VF 2 for float<->i32 is currently just as expensive as for VF 4.
9841021
if (VF == 2 && SrcScalarBits == 32 && DstScalarBits == 32)
@@ -990,8 +1027,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
9901027
if (Opcode == Instruction::FPTrunc) {
9911028
if (SrcScalarBits == 128) // fp128 -> double/float + inserts of elements.
9921029
return VF /*ldxbr/lexbr*/ +
993-
getScalarizationOverhead(DstVecTy, /*Insert*/ true,
994-
/*Extract*/ false, CostKind);
1030+
BaseT::getScalarizationOverhead(DstVecTy, /*Insert*/ true,
1031+
/*Extract*/ false, CostKind);
9951032
else // double -> float
9961033
return VF / 2 /*vledb*/ + std::max(1U, VF / 4 /*vperm*/);
9971034
}
@@ -1004,8 +1041,8 @@ InstructionCost SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
10041041
return VF * 2;
10051042
}
10061043
// -> fp128. VF * lxdb/lxeb + extraction of elements.
1007-
return VF + getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1008-
/*Extract*/ true, CostKind);
1044+
return VF + BaseT::getScalarizationOverhead(SrcVecTy, /*Insert*/ false,
1045+
/*Extract*/ true, CostKind);
10091046
}
10101047
}
10111048

@@ -1114,10 +1151,17 @@ InstructionCost SystemZTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
11141151
TTI::TargetCostKind CostKind,
11151152
unsigned Index, Value *Op0,
11161153
Value *Op1) {
1117-
// vlvgp will insert two grs into a vector register, so only count half the
1118-
// number of instructions.
1119-
if (Opcode == Instruction::InsertElement && Val->isIntOrIntVectorTy(64))
1120-
return ((Index % 2 == 0) ? 1 : 0);
1154+
if (Opcode == Instruction::InsertElement) {
1155+
// Vector Element Load.
1156+
if (Op1 != nullptr && isFreeEltLoad(Op1))
1157+
return 0;
1158+
1159+
// vlvgp will insert two grs into a vector register, so count half the
1160+
// number of instructions as an estimate when we don't have the full
1161+
// picture (as in getScalarizationOverhead()).
1162+
if (Val->isIntOrIntVectorTy(64))
1163+
return ((Index % 2 == 0) ? 1 : 0);
1164+
}
11211165

11221166
if (Opcode == Instruction::ExtractElement) {
11231167
int Cost = ((getScalarSizeInBits(Val) == 1) ? 2 /*+test-under-mask*/ : 1);

llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ class SystemZTTIImpl : public BasicTTIImplBase<SystemZTTIImpl> {
8181
bool hasDivRemOp(Type *DataType, bool IsSigned);
8282
bool prefersVectorizedAddressing() { return false; }
8383
bool LSRWithInstrQueries() { return true; }
84+
InstructionCost getScalarizationOverhead(VectorType *Ty,
85+
const APInt &DemandedElts,
86+
bool Insert, bool Extract,
87+
TTI::TargetCostKind CostKind,
88+
ArrayRef<Value *> VL = {});
8489
bool supportsEfficientVectorElementLoadStore() { return true; }
8590
bool enableInterleavedAccessVectorization() { return true; }
8691

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4854,10 +4854,9 @@ InstructionCost X86TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
48544854
RegisterFileMoveCost;
48554855
}
48564856

4857-
InstructionCost
4858-
X86TTIImpl::getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts,
4859-
bool Insert, bool Extract,
4860-
TTI::TargetCostKind CostKind) {
4857+
InstructionCost X86TTIImpl::getScalarizationOverhead(
4858+
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
4859+
TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
48614860
assert(DemandedElts.getBitWidth() ==
48624861
cast<FixedVectorType>(Ty)->getNumElements() &&
48634862
"Vector size mismatch");

llvm/lib/Target/X86/X86TargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
169169
InstructionCost getScalarizationOverhead(VectorType *Ty,
170170
const APInt &DemandedElts,
171171
bool Insert, bool Extract,
172-
TTI::TargetCostKind CostKind);
172+
TTI::TargetCostKind CostKind,
173+
ArrayRef<Value *> VL = {});
173174
InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
174175
int VF,
175176
const APInt &DemandedDstElts,

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3110,9 +3110,8 @@ class BoUpSLP {
31103110
SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries,
31113111
unsigned NumParts, bool ForOrder = false);
31123112

3113-
/// \returns the scalarization cost for this list of values. Assuming that
3114-
/// this subtree gets vectorized, we may need to extract the values from the
3115-
/// roots. This method calculates the cost of extracting the values.
3113+
/// \returns the cost of gathering (inserting) the values in \p VL into a
3114+
/// vector.
31163115
/// \param ForPoisonSrc true if initial vector is poison, false otherwise.
31173116
InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
31183117
Type *ScalarTy) const;
@@ -13498,9 +13497,10 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc,
1349813497
TTI::SK_InsertSubvector, VecTy, std::nullopt, CostKind,
1349913498
I * ScalarTyNumElements, cast<FixedVectorType>(ScalarTy));
1350013499
} else {
13501-
Cost = TTI->getScalarizationOverhead(VecTy, ~ShuffledElements,
13500+
Cost = TTI->getScalarizationOverhead(VecTy,
13501+
/*DemandedElts*/ ~ShuffledElements,
1350213502
/*Insert*/ true,
13503-
/*Extract*/ false, CostKind);
13503+
/*Extract*/ false, CostKind, VL);
1350413504
}
1350513505
}
1350613506
if (DuplicateNonConst)

0 commit comments

Comments
 (0)