[SLP] Allow targets to add cost for nonstandard conditions #95328

Status: Open. Wants to merge 5 commits into main.
16 changes: 16 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -891,6 +891,11 @@ class TargetTransformInfo {
bool Insert, bool Extract,
TTI::TargetCostKind CostKind) const;

/// Return true if some target-specific condition imposes an overhead for
/// scalarizing \p VL. If so, \p ScalarizationKind is set to the {Insert,
/// Extract} kinds of overhead that should be charged.
bool hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) const;

/// Estimate the overhead of scalarizing an instruction's unique
/// non-constant operands. The (potentially vector) types to use for each
/// argument are passed via Tys.
@@ -1921,6 +1926,10 @@ class TargetTransformInfo::Concept {
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
ArrayRef<Type *> Tys,
TargetCostKind CostKind) = 0;

virtual bool
hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) = 0;
virtual bool supportsEfficientVectorElementLoadStore() = 0;
virtual bool supportsTailCalls() = 0;
virtual bool supportsTailCallFor(const CallBase *CB) = 0;
@@ -2456,6 +2465,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
CostKind);
}

bool
hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) override {
return Impl.hasScalarizationOverhead(VL, VTy, ScalarizationKind);
}

InstructionCost
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
ArrayRef<Type *> Tys,
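The new hook follows the standard three-layer TTI plumbing visible above: a public method on TargetTransformInfo, a pure virtual on the Concept interface, and a forwarding override on the Model wrapper. A minimal sketch of how a caller might pair it with the existing getScalarizationOverhead API; the helper name and scaffolding are hypothetical, not part of this patch:

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"

using namespace llvm;

// Hypothetical helper: ask the target whether this bundle of scalars hits a
// nonstandard scalarization condition and, if so, price the insert/extract
// traffic for every lane of the would-be vector.
static InstructionCost
scalarizationCredit(const TargetTransformInfo &TTI, ArrayRef<Value *> VL,
                    FixedVectorType *VecTy, TTI::TargetCostKind CostKind) {
  std::pair<bool, bool> Kind(false, false); // {Insert, Extract}
  if (!TTI.hasScalarizationOverhead(VL, VecTy, Kind))
    return 0;
  APInt Demanded = APInt::getAllOnes(VecTy->getNumElements());
  return TTI.getScalarizationOverhead(VecTy, Demanded,
                                      /*Insert=*/Kind.first,
                                      /*Extract=*/Kind.second, CostKind);
}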
6 changes: 6 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -371,6 +371,12 @@ class TargetTransformInfoImplBase {
return 0;
}

bool
hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) const {
return false;
}

InstructionCost
getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
ArrayRef<Type *> Tys,
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -807,6 +807,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
CostKind);
}

bool
hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VT,
std::pair<bool, bool> &ScalarizationKind) const {
return false;
}

/// Estimate the overhead of scalarizing an instruction's unique
/// non-constant operands. The (potentially vector) types to use for each
/// argument are passed via Tys.
6 changes: 6 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -594,6 +594,12 @@ InstructionCost TargetTransformInfo::getScalarizationOverhead(
CostKind);
}

bool TargetTransformInfo::hasScalarizationOverhead(
ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) const {
return TTIImpl->hasScalarizationOverhead(VL, VTy, ScalarizationKind);
}

InstructionCost TargetTransformInfo::getOperandsScalarizationOverhead(
ArrayRef<const Value *> Args, ArrayRef<Type *> Tys,
TTI::TargetCostKind CostKind) const {
43 changes: 40 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -306,6 +306,18 @@ bool GCNTTIImpl::hasBranchDivergence(const Function *F) const {
return !F || !ST->isSingleLaneExecution(*F);
}

unsigned GCNTTIImpl::getNumberOfParts(Type *Tp) const {
// 8-bit elements are packed four to a 32-bit register, so a v4i8
// legalizes to a single part.
if (auto *VTy = dyn_cast<FixedVectorType>(Tp)) {
if (DL.getTypeSizeInBits(VTy->getElementType()) == 8) {
unsigned ElCount = VTy->getElementCount().getFixedValue();
return ElCount / 4;
}
}
}

std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
return LT.first.isValid() ? *LT.first.getValue() : 0;
}

unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const {
// NB: RCID is not an RCID. In fact it is 0 or 1 for scalar or vector
// registers. See getRegisterClassForType for the implementation.
@@ -337,9 +349,11 @@ unsigned GCNTTIImpl::getMinVectorRegisterBitWidth() const {
unsigned GCNTTIImpl::getMaximumVF(unsigned ElemWidth, unsigned Opcode) const {
if (Opcode == Instruction::Load || Opcode == Instruction::Store)
return 32 * 4 / ElemWidth;
-return (ElemWidth == 16 && ST->has16BitInsts()) ? 2
-  : (ElemWidth == 32 && ST->hasPackedFP32Ops()) ? 2
-  : 1;
+
+return (ElemWidth == 8) ? 4
+  : (ElemWidth == 16 && ST->has16BitInsts()) ? 2
+  : (ElemWidth == 32 && ST->hasPackedFP32Ops()) ? 2
+  : 1;
}

unsigned GCNTTIImpl::getLoadVectorFactor(unsigned VF, unsigned LoadSize,
@@ -1365,6 +1379,29 @@ int GCNTTIImpl::get64BitInstrCost(TTI::TargetCostKind CostKind) const {
: getQuarterRateInstrCost(CostKind);
}

bool GCNTTIImpl::hasScalarizationOverhead(
ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) const {
if (DL.getTypeSizeInBits(VTy->getElementType()) != 8)
return false;

// A lane whose scalar value feeds a PHI or is used outside its defining
// block forces the packed bytes to be unpacked and repacked.
for (Value *V : VL) {
auto *Inst = dyn_cast<Instruction>(V);
if (!Inst)
continue;
for (User *IU : Inst->users()) {
Instruction *UseInst = cast<Instruction>(IU);
if (UseInst->getOpcode() == Instruction::PHI ||
UseInst->getParent() != Inst->getParent()) {
ScalarizationKind = {true, true};
return true;
}
}
}

return false;
}

std::pair<InstructionCost, MVT>
GCNTTIImpl::getTypeLegalizationCost(Type *Ty) const {
std::pair<InstructionCost, MVT> Cost = BaseT::getTypeLegalizationCost(Ty);
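The GCN changes above reduce to two packing rules and one escape heuristic: 8-bit elements live four to a 32-bit register (so a v4i8 is one part and the maximum VF at ElemWidth 8 is 4), and scalarization overhead is only reported when a lane's value feeds a PHI or escapes its block. A standalone model of the sizing arithmetic, with the DataLayout and subtarget queries replaced by plain parameters for illustration:

#include <cstdio>

// Toy model of the two GCN sizing rules; the real implementations query
// DataLayout and subtarget features instead of taking booleans.
static unsigned numberOfPartsForI8(unsigned NumElts) {
  return NumElts / 4; // four i8 lanes per 32-bit register: v4i8 -> 1 part
}

static unsigned maximumVF(unsigned ElemWidth, bool Has16BitInsts,
                          bool HasPackedFP32Ops) {
  return (ElemWidth == 8)                        ? 4
         : (ElemWidth == 16 && Has16BitInsts)    ? 2
         : (ElemWidth == 32 && HasPackedFP32Ops) ? 2
                                                 : 1;
}

int main() {
  std::printf("v4i8 parts=%u, v8i8 parts=%u, maxVF(i8)=%u\n",
              numberOfPartsForI8(4), numberOfPartsForI8(8),
              maximumVF(8, true, true));
  return 0;
}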
4 changes: 4 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -117,6 +117,7 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
return TTI::PSK_FastHardware;
}

unsigned getNumberOfParts(Type *Tp) const;
unsigned getNumberOfRegisters(unsigned RCID) const;
TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind Vector) const;
unsigned getMinVectorRegisterBitWidth() const;
@@ -256,6 +257,9 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
FastMathFlags FMF,
TTI::TargetCostKind CostKind);

bool hasScalarizationOverhead(ArrayRef<Value *> VL, FixedVectorType *VTy,
std::pair<bool, bool> &ScalarizationKind) const;

/// Data cache line size for LoopDataPrefetch pass. Has no use before GFX12.
unsigned getCacheLineSize() const override { return 128; }

8 changes: 8 additions & 0 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9084,6 +9084,14 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
}
InstructionCost CommonCost = 0;
std::pair<bool, bool> ScalarizationKind(false, false);
if (TTI->hasScalarizationOverhead(VL, FinalVecTy, ScalarizationKind)) {
APInt DemandedElts = APInt::getAllOnes(VL.size());
CommonCost -= TTI->getScalarizationOverhead(
VecTy, DemandedElts,
/*Insert*/ ScalarizationKind.first,
/*Extract*/ ScalarizationKind.second, CostKind);
}
SmallVector<int> Mask;
bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
if (!E->ReorderIndices.empty() &&
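The SLP change turns the target's answer into a cost credit: CommonCost is reduced by the full insert-plus-extract scalarization overhead, so vectorization wins the comparison precisely when keeping the bundle scalar would force expensive packing of sub-register elements. A toy walk-through with made-up cost numbers, not taken from any real cost model:

#include <cstdio>

int main() {
  // Invented costs for a 4 x i8 bundle: raw vector cost, scalar cost, and
  // the insert+extract overhead the target reports when lanes escape their
  // block (see the AMDGPU hook above).
  int VecCost = 4, ScalarCost = 5, ScalarizationOverhead = 3;
  int CommonCost = 0;
  bool TargetReportsOverhead = true; // models TTI->hasScalarizationOverhead
  if (TargetReportsOverhead)
    CommonCost -= ScalarizationOverhead;
  // SLP effectively compares VecCost + CommonCost against ScalarCost; the
  // credit tips borderline bundles toward the vector form.
  std::printf("vector=%d scalar=%d -> %s\n", VecCost + CommonCost, ScalarCost,
              VecCost + CommonCost < ScalarCost ? "vectorize" : "keep scalar");
  return 0;
}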
@@ -363,11 +363,66 @@ bb:
ret <4 x i16> %ins.3
}

define <4 x i8> @uadd_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1) {
; GCN-LABEL: @uadd_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.uadd.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.uadd.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.uadd.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.uadd.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.uadd.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> poison, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

define <4 x i8> @usub_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1) {
; GCN-LABEL: @usub_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.usub.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.usub.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.usub.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.usub.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.usub.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> poison, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

declare i16 @llvm.uadd.sat.i16(i16, i16) #0
declare i16 @llvm.usub.sat.i16(i16, i16) #0
declare i16 @llvm.sadd.sat.i16(i16, i16) #0
declare i16 @llvm.ssub.sat.i16(i16, i16) #0

declare i8 @llvm.uadd.sat.i8(i8, i8) #0
declare i8 @llvm.usub.sat.i8(i8, i8) #0

declare i32 @llvm.uadd.sat.i32(i32, i32) #0
declare i32 @llvm.usub.sat.i32(i32, i32) #0
declare i32 @llvm.sadd.sat.i32(i32, i32) #0
56 changes: 56 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/AMDGPU/add_sub_sat.ll
@@ -363,11 +363,67 @@ bb:
ret <4 x i16> %ins.3
}

define <4 x i8> @uadd_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1, ptr addrspace(1) %dst) {
; GCN-LABEL: @uadd_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.uadd.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.uadd.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.uadd.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.uadd.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.uadd.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> undef, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

define <4 x i8> @usub_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1) {
; GCN-LABEL: @usub_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.usub.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.usub.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.usub.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.usub.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.usub.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> undef, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

declare i16 @llvm.uadd.sat.i16(i16, i16) #0
declare i16 @llvm.usub.sat.i16(i16, i16) #0
declare i16 @llvm.sadd.sat.i16(i16, i16) #0
declare i16 @llvm.ssub.sat.i16(i16, i16) #0

declare i8 @llvm.uadd.sat.i8(i8, i8) #0
declare i8 @llvm.usub.sat.i8(i8, i8) #0

declare i32 @llvm.uadd.sat.i32(i32, i32) #0
declare i32 @llvm.usub.sat.i32(i32, i32) #0
declare i32 @llvm.sadd.sat.i32(i32, i32) #0