Skip to content

[LV] Refactor vector function variant selection to prepare for uniform args #68879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7009,39 +7009,52 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {

// Find the cost of vectorizing the call, if we can find a suitable
// vector variant of the function.
InstructionCost MaskCost = 0;
VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
bool UsesMask = MaskRequired;
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
// If we want an unmasked vector function but can't find one matching the
// VF, maybe we can find vector function that does use a mask and
// synthesize an all-true mask.
if (!VecFunc && !MaskRequired) {
Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
// If we found one, add in the cost of creating a mask
if (VecFunc) {
UsesMask = true;
MaskCost = TTI.getShuffleCost(
TargetTransformInfo::SK_Broadcast,
VectorType::get(IntegerType::getInt1Ty(
VecFunc->getFunctionType()->getContext()),
VF));
}
}
bool UsesMask = false;
VFInfo FuncInfo;
Function *VecFunc = nullptr;
// Search through any available variants for one we can use at this VF.
for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
// Must match requested VF.
if (Info.Shape.VF != VF)
continue;

std::optional<unsigned> MaskPos = std::nullopt;
if (VecFunc && UsesMask) {
for (const VFInfo &Info : VFDatabase::getMappings(*CI))
if (Info.Shape == Shape) {
assert(Info.isMasked() && "Vector function info shape mismatch");
MaskPos = Info.getParamIndexForOptionalMask().value();
// Must take a mask argument if one is required
if (MaskRequired && !Info.isMasked())
continue;

// Check that all parameter kinds are supported
bool ParamsOk = true;
for (VFParameter Param : Info.Shape.Parameters) {
switch (Param.ParamKind) {
case VFParamKind::Vector:
break;
case VFParamKind::GlobalPredicate:
UsesMask = true;
break;
default:
ParamsOk = false;
break;
}
}

if (!ParamsOk)
continue;

assert(MaskPos.has_value() && "Unable to find mask parameter index");
// Found a suitable candidate, stop here.
VecFunc = CI->getModule()->getFunction(Info.VectorName);
FuncInfo = Info;
break;
}

// Add in the cost of synthesizing a mask if one wasn't required.
InstructionCost MaskCost = 0;
if (VecFunc && UsesMask && !MaskRequired)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't functionality related to support of the uniform paramaters, and in my opinion should go as a separate patch with extra tests checking the increased cost.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's just refactoring of existing code, not new functionality.

MaskCost = TTI.getShuffleCost(
TargetTransformInfo::SK_Broadcast,
VectorType::get(IntegerType::getInt1Ty(
VecFunc->getFunctionType()->getContext()),
VF));

if (TLI && VecFunc && !CI->isNoBuiltin())
VectorCost =
TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
Expand All @@ -7065,7 +7078,8 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
Decision = CM_IntrinsicCall;
}

setCallWideningDecision(CI, VF, Decision, VecFunc, IID, MaskPos, Cost);
setCallWideningDecision(CI, VF, Decision, VecFunc, IID,
FuncInfo.getParamIndexForOptionalMask(), Cost);
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
"DbgInfoIntrinsic should have been dropped during VPlan construction");
State.setDebugLocFrom(CI.getDebugLoc());

FunctionType *VFTy = nullptr;
if (Variant)
VFTy = Variant->getFunctionType();
for (unsigned Part = 0; Part < State.UF; ++Part) {
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
Expand All @@ -514,12 +517,15 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
for (const auto &I : enumerate(operands())) {
// Some intrinsics have a scalar argument - don't replace it with a
// vector.
// Some vectorized function variants may also take a scalar argument,
// e.g. linear parameters for pointers.
Value *Arg;
if (VectorIntrinsicID == Intrinsic::not_intrinsic ||
!isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))
Arg = State.get(I.value(), Part);
else
if ((VFTy && !VFTy->getParamType(I.index())->isVectorTy()) ||
(VectorIntrinsicID != Intrinsic::not_intrinsic &&
isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())))
Arg = State.get(I.value(), VPIteration(0, 0));
else
Arg = State.get(I.value(), Part);
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
TysForDecl.push_back(Arg->getType());
Args.push_back(Arg);
Expand Down