[SLP]Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. #86135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
102 changes: 88 additions & 14 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7056,19 +7056,16 @@ bool BoUpSLP::areAllUsersVectorized(

static std::pair<InstructionCost, InstructionCost>
getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
- TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+ TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+ ArrayRef<Type *> ArgTys) {
Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);

// Calculate the cost of the scalar and vector calls.
- SmallVector<Type *, 4> VecTys;
- for (Use &Arg : CI->args())
-   VecTys.push_back(
-       FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
FastMathFlags FMF;
if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
FMF = FPCI->getFastMathFlags();
SmallVector<const Value *> Arguments(CI->args());
- IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF,
+ IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, ArgTys, FMF,
dyn_cast<IntrinsicInst>(CI));
auto IntrinsicCost =
TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
@@ -7081,8 +7078,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
if (!CI->isNoBuiltin() && VecFunc) {
// Calculate the cost of the vector library call.
// If the corresponding vector call is cheaper, return its cost.
- LibCost = TTI->getCallInstrCost(nullptr, VecTy, VecTys,
-                                 TTI::TCK_RecipThroughput);
+ LibCost =
+     TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput);
}
return {IntrinsicCost, LibCost};
}
@@ -8508,6 +8505,30 @@ TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const {
return TTI::CastContextHint::None;
}

+ /// Builds the argument types vector for the given call instruction with the
+ /// given \p ID for the specified vector factor.
+ static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI,
+                                                   const Intrinsic::ID ID,
+                                                   const unsigned VF,
+                                                   unsigned MinBW) {
+   SmallVector<Type *> ArgTys;
+   for (auto [Idx, Arg] : enumerate(CI->args())) {
+     if (ID != Intrinsic::not_intrinsic) {
+       if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+         ArgTys.push_back(Arg->getType());
+         continue;
+       }
+       if (MinBW > 0) {
+         ArgTys.push_back(FixedVectorType::get(
+             IntegerType::get(CI->getContext(), MinBW), VF));
+         continue;
+       }
+     }
+     ArgTys.push_back(FixedVectorType::get(Arg->getType(), VF));
+   }
+   return ArgTys;
+ }

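For illustration, a minimal LLVM IR sketch (not part of the patch) of the signatures the helper's result corresponds to, assuming VF = 4 and a demoted bitwidth MinBW = 16 for calls that originally operated on i32:

; For llvm.smax both data operands are narrowed to the demoted element type:
declare <4 x i16> @llvm.smax.v4i16(<4 x i16>, <4 x i16>)
; For llvm.abs the second operand is a scalar flag (scalar-op-at-arg), so it
; keeps its original i1 type and only the data operand is narrowed:
declare <4 x i16> @llvm.abs.v4i16(<4 x i16>, i1)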
InstructionCost
BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
SmallPtrSetImpl<Value *> &CheckedExtracts) {
Expand Down Expand Up @@ -9074,7 +9095,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
};
auto GetVectorCost = [=](InstructionCost CommonCost) {
auto *CI = cast<CallInst>(VL0);
- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+ SmallVector<Type *> ArgTys =
+     buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                            It != MinBWs.end() ? It->second.first : 0);
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
};
return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12546,7 +12571,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {

Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);

- auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+ SmallVector<Type *> ArgTys =
+     buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(),
+                            It != MinBWs.end() ? It->second.first : 0);
+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
VecCallCosts.first <= VecCallCosts.second;

@@ -12555,16 +12583,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
- TysForDecl.push_back(
-     FixedVectorType::get(CI->getType(), E->Scalars.size()));
+ TysForDecl.push_back(VecTy);
auto *CEI = cast<CallInst>(VL0);
for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
ValueList OpVL;
// Some intrinsics have scalar arguments. This argument should not be
// vectorized.
if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
ScalarArg = CEI->getArgOperand(I);
- OpVecs.push_back(CEI->getArgOperand(I));
+ // If we decided to reduce the bitwidth of the abs intrinsic, its second
+ // argument must be set to false (do not return poison if the value is
+ // the signed minimum).
+ if (ID == Intrinsic::abs && It != MinBWs.end() &&
+     It->second.first < DL->getTypeSizeInBits(CEI->getType()))
+   ScalarArg = Builder.getFalse();
+ OpVecs.push_back(ScalarArg);
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
TysForDecl.push_back(ScalarArg->getType());
continue;
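A minimal IR sketch (hypothetical values, not from the patch) of why the flag has to be cleared once abs is narrowed: a poison-on-signed-minimum case that could not occur at the wider type can trigger at the demoted type.

; At i8, abs of the signed minimum wraps; with the poison flag set it is poison.
%p = call i8 @llvm.abs.i8(i8 -128, i1 true)  ; poison
%v = call i8 @llvm.abs.i8(i8 -128, i1 false) ; -128, which matches the
                                             ; truncated result of a wider abs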
@@ -12577,10 +12609,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
}
ScalarArg = CEI->getArgOperand(I);
if (cast<VectorType>(OpVec->getType())->getElementType() !=
- ScalarArg->getType()) {
+ ScalarArg->getType() &&
+ It == MinBWs.end()) {
auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
VecTy->getNumElements());
OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+ } else if (It != MinBWs.end()) {
+   OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
}
LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
OpVecs.push_back(OpVec);
@@ -14324,6 +14359,45 @@ bool BoUpSLP::collectValuesToDemote(
return TryProcessInstruction(I, *ITE, BitWidth, Ops);
}

+ case Instruction::Call: {
+   auto *IC = dyn_cast<IntrinsicInst>(I);
+   if (!IC)
+     break;
+   Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+   if (ID != Intrinsic::abs && ID != Intrinsic::smin &&
+       ID != Intrinsic::smax && ID != Intrinsic::umin && ID != Intrinsic::umax)
+     break;
+   SmallVector<Value *> Operands(1, I->getOperand(0));
+   End = 1;
+   if (ID != Intrinsic::abs) {
+     Operands.push_back(I->getOperand(1));
+     End = 2;
+   }
+   InstructionCost BestCost =
+       std::numeric_limits<InstructionCost::CostType>::max();
+   unsigned BestBitWidth = BitWidth;
+   unsigned VF = ITE->Scalars.size();
+   // Choose the best bitwidth based on cost estimations.
+   auto Checker = [&](unsigned BitWidth, unsigned) {
+     unsigned MinBW = PowerOf2Ceil(BitWidth);
+     SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(IC, ID, VF, MinBW);
+     auto VecCallCosts = getVectorCallCosts(
+         IC,
+         FixedVectorType::get(IntegerType::get(IC->getContext(), MinBW), VF),
+         TTI, TLI, ArgTys);
+     InstructionCost Cost = std::min(VecCallCosts.first, VecCallCosts.second);
+     if (Cost < BestCost) {
+       BestCost = Cost;
+       BestBitWidth = BitWidth;
+     }
+     return false;
+   };
+   [[maybe_unused]] bool NeedToExit;
+   (void)AttemptCheckBitwidth(Checker, NeedToExit);
+   BitWidth = BestBitWidth;
+   return TryProcessInstruction(I, *ITE, BitWidth, Operands);
+ }

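The effect of this new case, sketched in IR under the assumption that the target's cost model favors the narrow call (value names hypothetical): a min/max call that only ever sees extended narrow values can be performed at the narrow type directly, which is exactly the rewrite the test diffs below check for.

; Before demotion: umax performed at the original i32 width.
%a32 = zext <4 x i8> %a to <4 x i32>
%b32 = zext <4 x i8> %b to <4 x i32>
%m32 = call <4 x i32> @llvm.umax.v4i32(<4 x i32> %a32, <4 x i32> %b32)
%m8  = trunc <4 x i32> %m32 to <4 x i8>
; After demotion: the same result computed directly on i8.
%m   = call <4 x i8> @llvm.umax.v4i8(<4 x i8> %a, <4 x i8> %b)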
// Otherwise, conservatively give up.
default:
break;
@@ -11,9 +11,7 @@ define void @test(ptr %0, i8 %1, i1 %cmp12.i) {
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <8 x i8> [[TMP4]], <8 x i8> poison, <8 x i32> zeroinitializer
; CHECK-NEXT: br label [[PRE:%.*]]
; CHECK: pre:
- ; CHECK-NEXT: [[TMP6:%.*]] = zext <8 x i8> [[TMP5]] to <8 x i32>
- ; CHECK-NEXT: [[TMP7:%.*]] = call <8 x i32> @llvm.umax.v8i32(<8 x i32> [[TMP6]], <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>)
- ; CHECK-NEXT: [[TMP8:%.*]] = trunc <8 x i32> [[TMP7]] to <8 x i8>
+ ; CHECK-NEXT: [[TMP8:%.*]] = call <8 x i8> @llvm.umax.v8i8(<8 x i8> [[TMP5]], <8 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
; CHECK-NEXT: [[TMP9:%.*]] = add <8 x i8> [[TMP8]], <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
; CHECK-NEXT: [[TMP10:%.*]] = select <8 x i1> [[TMP3]], <8 x i8> [[TMP9]], <8 x i8> [[TMP5]]
; CHECK-NEXT: store <8 x i8> [[TMP10]], ptr [[TMP0]], align 1
@@ -5,12 +5,14 @@ define void @test() {
; CHECK-LABEL: define void @test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: entry:
- ; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
- ; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
- ; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
- ; CHECK-NEXT: [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+ ; CHECK-NEXT: [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+ ; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+ ; CHECK-NEXT: [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+ ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+ ; CHECK-NEXT: [[ADD:%.*]] = zext i2 [[TMP3]] to i32
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[ADD]], 0
- ; CHECK-NEXT: [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+ ; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+ ; CHECK-NEXT: [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
; CHECK-NEXT: [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
; CHECK-NEXT: [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
; CHECK-NEXT: [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
- ; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+ ; CHECK-NEXT: [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
- ; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
- ; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
- ; CHECK-NEXT: [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
- ; CHECK-NEXT: [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+ ; CHECK-NEXT: [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+ ; CHECK-NEXT: [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+ ; CHECK-NEXT: [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 false)
; CHECK-NEXT: store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
; CHECK-NEXT: ret i32 undef
;