
[AMDGPU] Allow SLP to analyze i8s #113002

Open · wants to merge 2 commits into main
13 changes: 13 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1534,6 +1534,14 @@ class TargetTransformInfo {
Function *F, Type *RetTy, ArrayRef<Type *> Tys,
TTI::TargetCostKind CostKind = TTI::TCK_SizeAndLatency) const;

/// \returns The cost of propagating Type \p DataType through Basic Block /
/// function boundaries. If \p IsCallingConv is set, then \p DataType is
/// associated with either a function argument or return value. Otherwise, \p
/// DataType is either used in a GEP instruction or spans across BasicBlocks
/// (this is relevant because the SelectionDAG builder may, for example,
/// scalarize illegal vectors across blocks, which introduces extract/insert
/// code).
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) const;

/// \returns The number of pieces into which the provided type must be
/// split during legalization. Zero is returned when the answer is unknown.
unsigned getNumberOfParts(Type *Tp) const;
@@ -2096,6 +2104,8 @@ class TargetTransformInfo::Concept {
virtual InstructionCost getCallInstrCost(Function *F, Type *RetTy,
ArrayRef<Type *> Tys,
TTI::TargetCostKind CostKind) = 0;
virtual InstructionCost getDataFlowCost(Type *DataType,
bool IsCallingConv) = 0;
virtual unsigned getNumberOfParts(Type *Tp) = 0;
virtual InstructionCost
getAddressComputationCost(Type *Ty, ScalarEvolution *SE, const SCEV *Ptr) = 0;
@@ -2781,6 +2791,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
TTI::TargetCostKind CostKind) override {
return Impl.getCallInstrCost(F, RetTy, Tys, CostKind);
}
InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) override {
return Impl.getDataFlowCost(DataType, IsCallingConv);
}
unsigned getNumberOfParts(Type *Tp) override {
return Impl.getNumberOfParts(Tp);
}
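For context, a minimal sketch of how a client with a TargetTransformInfo handle could query the new hook (the helper name is illustrative and not part of this patch):

#include "llvm/Analysis/TargetTransformInfo.h"
using namespace llvm;

// Extra cost of moving a value of type Ty across basic-block or call
// boundaries, as exposed by the new hook. Targets that do not override
// getDataFlowCost report a cost of 0.
static InstructionCost getCrossBoundaryCost(const TargetTransformInfo &TTI,
                                            Type *Ty, bool IsCallingConv) {
  return TTI.getDataFlowCost(Ty, IsCallingConv);
}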
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -772,6 +772,10 @@ class TargetTransformInfoImplBase {
return 1;
}

InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) const {
return 0;
}

// Assume that we have a register of the right size for the type.
unsigned getNumberOfParts(Type *Tp) const { return 1; }

4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2410,6 +2410,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return 10;
}

InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv) {
return 0;
}

unsigned getNumberOfParts(Type *Tp) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
return LT.first.isValid() ? *LT.first.getValue() : 0;
7 changes: 7 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1116,6 +1116,13 @@ TargetTransformInfo::getCallInstrCost(Function *F, Type *RetTy,
return Cost;
}

InstructionCost TargetTransformInfo::getDataFlowCost(Type *DataType,
bool IsCallingConv) const {
InstructionCost Cost = TTIImpl->getDataFlowCost(DataType, IsCallingConv);
assert(Cost >= 0 && "TTI should not produce negative costs!");
return Cost;
}

unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const {
return TTIImpl->getNumberOfParts(Tp);
}
33 changes: 30 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -306,6 +306,31 @@ bool GCNTTIImpl::hasBranchDivergence(const Function *F) const {
return !F || !ST->isSingleLaneExecution(*F);
}

InstructionCost GCNTTIImpl::getDataFlowCost(Type *DataType,
bool IsCallingConv) {
if (isTypeLegal(DataType) || IsCallingConv)
return BaseT::getDataFlowCost(DataType, IsCallingConv);

return getNumberOfParts(DataType);
}

unsigned GCNTTIImpl::getNumberOfParts(Type *Tp) {
// For certain 8-bit ops, we can pack a v4i8 into a single part
// (e.g. v4i8 shufflevector -> v_perm v4i8, v4i8). Thus, we
// do not limit the number of parts for 8-bit vectors to their
// legalization cost. It is left to other target queries
// (e.g. get*InstrCost) to decide the proper handling of
// 8-bit vectors.
if (FixedVectorType *VTy = dyn_cast<FixedVectorType>(Tp)) {
if (DL.getTypeSizeInBits(VTy->getElementType()) == 8) {
unsigned ElCount = VTy->getElementCount().getFixedValue();
return ElCount / 4;
}
}

return BaseT::getNumberOfParts(Tp);
}

unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const {
// NB: RCID is not an RCID. In fact it is 0 or 1 for scalar or vector
// registers. See getRegisterClassForType for the implementation.
@@ -337,9 +362,11 @@ unsigned GCNTTIImpl::getMinVectorRegisterBitWidth() const {
unsigned GCNTTIImpl::getMaximumVF(unsigned ElemWidth, unsigned Opcode) const {
if (Opcode == Instruction::Load || Opcode == Instruction::Store)
return 32 * 4 / ElemWidth;
return (ElemWidth == 16 && ST->has16BitInsts()) ? 2
: (ElemWidth == 32 && ST->hasPackedFP32Ops()) ? 2
: 1;

return (ElemWidth == 8) ? 4
: (ElemWidth == 16) ? 2
: (ElemWidth == 32 && ST->hasPackedFP32Ops()) ? 2
: 1;
}

unsigned GCNTTIImpl::getLoadVectorFactor(unsigned VF, unsigned LoadSize,
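As a rough illustration of the overrides above (assuming a GCNTTIImpl instance and an LLVMContext; this snippet is not part of the patch):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

void illustrateI8Costs(GCNTTIImpl &TTI, LLVMContext &Ctx) {
  // <8 x i8> now reports 8 / 4 == 2 parts (two packed v4i8 values) instead of
  // being limited by its legalization cost.
  auto *V8I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 8);
  unsigned Parts = TTI.getNumberOfParts(V8I8); // 2 with this patch
  // getMaximumVF now permits a vectorization factor of 4 for 8-bit elements
  // on non-load/store opcodes.
  unsigned MaxVF = TTI.getMaximumVF(/*ElemWidth=*/8, Instruction::Add); // 4
  (void)Parts;
  (void)MaxVF;
}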
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -117,6 +117,7 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
return TTI::PSK_FastHardware;
}

unsigned getNumberOfParts(Type *Tp);
unsigned getNumberOfRegisters(unsigned RCID) const;
TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind Vector) const;
unsigned getMinVectorRegisterBitWidth() const;
@@ -161,6 +162,8 @@ class GCNTTIImpl final : public BasicTTIImplBase<GCNTTIImpl> {
InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);

InstructionCost getDataFlowCost(Type *DataType, bool IsCallingConv);

bool isInlineAsmSourceOfDivergence(const CallInst *CI,
ArrayRef<unsigned> Indices = {}) const;

91 changes: 85 additions & 6 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9044,6 +9044,51 @@ static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
return ArgTys;
}

// The cost model may determine that vectorizing and eliminating a series of
// ExtractElements is beneficial. However, if the input vector is a function
// argument, the calling convention may require extractions in the generated
// code. In this scenario, vectorization would not eliminate the
// ExtractElement sequence, but would add additional vectorization code.
// getCCCostFromScalars does the proper accounting for this.
static unsigned getCCCostFromScalars(ArrayRef<Value *> &Scalars,
unsigned ScalarSize,
TargetTransformInfo *TTI) {
SetVector<Value *> ArgRoots;
for (unsigned I = 0; I < ScalarSize; I++) {
auto *Scalar = Scalars[I];
if (!Scalar)
continue;
auto *EE = dyn_cast<ExtractElementInst>(Scalar);
if (!EE)
continue;

auto *Vec = EE->getOperand(0);
if (!Vec->getType()->isVectorTy())
continue;

auto F = EE->getFunction();
auto FoundIt = find_if(
F->args(), [&Vec](Argument &I) { return Vec == cast<Value>(&I); });

if (FoundIt == F->arg_end())
continue;

if (!ArgRoots.contains(Vec))
ArgRoots.insert(Vec);
}

if (!ArgRoots.size())
return 0;

unsigned Cost = 0;
for (auto ArgOp : ArgRoots) {
Cost += TTI->getDataFlowCost(ArgOp->getType(), /*IsCallingConv*/ true)
.getValue()
.value_or(0);
}
return Cost;
}

InstructionCost
BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
SmallPtrSetImpl<Value *> &CheckedExtracts) {
@@ -9075,15 +9120,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
auto *FinalVecTy = FixedVectorType::get(ScalarTy, EntryVF);

bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty();
InstructionCost CommonCost = getCCCostFromScalars(VL, VL.size(), TTI);
if (E->State == TreeEntry::NeedToGather) {
if (allConstant(VL))
return 0;
return CommonCost;
if (isa<InsertElementInst>(VL[0]))
return InstructionCost::getInvalid();
return processBuildVector<ShuffleCostEstimator, InstructionCost>(
E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
return CommonCost +
processBuildVector<ShuffleCostEstimator, InstructionCost>(
E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts);
}
InstructionCost CommonCost = 0;
SmallVector<int> Mask;
bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
if (!E->ReorderIndices.empty() &&
@@ -10241,6 +10287,31 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {

InstructionCost C = getEntryCost(&TE, VectorizedVals, CheckedExtracts);
Cost += C;

// Calculate the cost difference of propagating a vector versus a series of
// scalars across blocks. This may be nonzero in the case of illegal vectors.
Instruction *VL0 = TE.getMainOp();
bool IsAPhi = VL0 && isa<PHINode>(VL0);
bool HasNextEntry = VL0 && ((I + 1) < VectorizableTree.size());
bool LiveThru = false;
if (HasNextEntry) {
Instruction *VL1 = VectorizableTree[I + 1]->getMainOp();
LiveThru = VL1 && (VL0->getParent() != VL1->getParent());
}
if (IsAPhi || LiveThru) {
VectorType *VTy = dyn_cast<VectorType>(VL0->getType());
Type *ScalarTy = VTy ? VTy->getElementType() : VL0->getType();
if (ScalarTy && isValidElementType(ScalarTy)) {
InstructionCost ScalarDFlow =
TTI->getDataFlowCost(ScalarTy,
/*IsCallingConv*/ false) *
TE.getVectorFactor();
InstructionCost VectorDFlow =
TTI->getDataFlowCost(FixedVectorType::get(ScalarTy, TE.getVectorFactor()), /*IsCallingConv*/ false);
Cost += (VectorDFlow - ScalarDFlow);
}
}

LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle "
<< shortBundleName(TE.Scalars) << ".\n"
<< "SLP: Current total cost = " << Cost << "\n");
@@ -10257,15 +10328,24 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
for (ExternalUser &EU : ExternalUses) {
// We only add extract cost once for the same scalar.
if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
!ExtractCostCalculated.insert(EU.Scalar).second)
!ExtractCostCalculated.insert(EU.Scalar).second) {
continue;
}

// Uses by ephemeral values are free (because the ephemeral value will be
// removed prior to code generation, and so the extraction will be
// removed as well).
if (EphValues.count(EU.User))
continue;

// Account for any additional costs required by the calling convention for
// the type.
if (isa_and_nonnull<ReturnInst>(EU.User)) {
Cost +=
TTI->getDataFlowCost(EU.Scalar->getType(), /*IsCallingConv*/ true);
continue;
}

// No extract cost for vector "scalar"
if (isa<FixedVectorType>(EU.Scalar->getType()))
continue;
@@ -10566,7 +10646,6 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
if (ViewSLPTree)
ViewGraph(this, "SLP" + F->getName(), false, Str);
#endif

return Cost;
}

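To make the new cross-block accounting in getTreeCost concrete, the same delta can be written as a standalone helper (the helper name is hypothetical; the arithmetic mirrors the diff above):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DerivedTypes.h"
using namespace llvm;

// Cost difference between carrying one vector of VF elements across a block
// boundary and carrying VF individual scalars across it. A positive result
// penalizes vectorization (e.g. an illegal vector that the SelectionDAG
// builder would scalarize across blocks); zero means the target is
// indifferent.
static InstructionCost crossBlockFlowDelta(const TargetTransformInfo &TTI,
                                           Type *ScalarTy, unsigned VF) {
  InstructionCost ScalarFlow =
      TTI.getDataFlowCost(ScalarTy, /*IsCallingConv=*/false) * VF;
  InstructionCost VectorFlow = TTI.getDataFlowCost(
      FixedVectorType::get(ScalarTy, VF), /*IsCallingConv=*/false);
  return VectorFlow - ScalarFlow;
}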
@@ -363,11 +363,66 @@ bb:
ret <4 x i16> %ins.3
}

define <4 x i8> @uadd_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1) {
; GCN-LABEL: @uadd_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.uadd.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.uadd.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.uadd.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.uadd.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.uadd.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> poison, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

define <4 x i8> @usub_sat_v4i8(<4 x i8> %arg0, <4 x i8> %arg1) {
; GCN-LABEL: @usub_sat_v4i8(
; GCN-NEXT: bb:
; GCN-NEXT: [[TMP0:%.*]] = call <4 x i8> @llvm.usub.sat.v4i8(<4 x i8> [[ARG0:%.*]], <4 x i8> [[ARG1:%.*]])
; GCN-NEXT: ret <4 x i8> [[TMP0]]
;
bb:
%arg0.0 = extractelement <4 x i8> %arg0, i64 0
%arg0.1 = extractelement <4 x i8> %arg0, i64 1
%arg0.2 = extractelement <4 x i8> %arg0, i64 2
%arg0.3 = extractelement <4 x i8> %arg0, i64 3
%arg1.0 = extractelement <4 x i8> %arg1, i64 0
%arg1.1 = extractelement <4 x i8> %arg1, i64 1
%arg1.2 = extractelement <4 x i8> %arg1, i64 2
%arg1.3 = extractelement <4 x i8> %arg1, i64 3
%add.0 = call i8 @llvm.usub.sat.i8(i8 %arg0.0, i8 %arg1.0)
%add.1 = call i8 @llvm.usub.sat.i8(i8 %arg0.1, i8 %arg1.1)
%add.2 = call i8 @llvm.usub.sat.i8(i8 %arg0.2, i8 %arg1.2)
%add.3 = call i8 @llvm.usub.sat.i8(i8 %arg0.3, i8 %arg1.3)
%ins.0 = insertelement <4 x i8> poison, i8 %add.0, i64 0
%ins.1 = insertelement <4 x i8> %ins.0, i8 %add.1, i64 1
%ins.2 = insertelement <4 x i8> %ins.1, i8 %add.2, i64 2
%ins.3 = insertelement <4 x i8> %ins.2, i8 %add.3, i64 3
ret <4 x i8> %ins.3
}

declare i16 @llvm.uadd.sat.i16(i16, i16) #0
declare i16 @llvm.usub.sat.i16(i16, i16) #0
declare i16 @llvm.sadd.sat.i16(i16, i16) #0
declare i16 @llvm.ssub.sat.i16(i16, i16) #0

declare i8 @llvm.uadd.sat.i8(i8, i8) #0
declare i8 @llvm.usub.sat.i8(i8, i8) #0

declare i32 @llvm.uadd.sat.i32(i32, i32) #0
declare i32 @llvm.usub.sat.i32(i32, i32) #0
declare i32 @llvm.sadd.sat.i32(i32, i32) #0