Skip to content

[TLI] ReplaceWithVecLib pass uses CostModel #78688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 69 additions & 12 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g. libmvec, SVML) using TargetLibraryInfo interface.
// The replacement happens only when the vector library call is not estimated
// to cost more than the original instruction.
//
//===----------------------------------------------------------------------===//

Expand All @@ -20,12 +22,16 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

Expand Down Expand Up @@ -96,15 +102,55 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Replacement->copyFastMathFlags(&I);
}

/// Returns whether the vector library call \p TLIFunc costs more than the
/// original instruction \p I.
static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
const TargetTransformInfo &TTI, Instruction &I,
VectorType *VectorTy, CallInst *CI,
Copy link
Collaborator

@huntergr-arm huntergr-arm Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just pass the instruction and re-cast it here (possibly directly to IntrinsicInst?) instead of passing both.

Function *TLIFunc) {
SmallVector<Type *, 4> OpTypes;
for (auto &Op : CI ? CI->args() : I.operands())
Copy link
Collaborator

@huntergr-arm huntergr-arm Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should only need the I.operands(); also consider using Value * instead of auto &

OpTypes.push_back(Op->getType());

TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost DefaultCost;
if (CI) {
FastMathFlags FMF;
if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
FMF = FPMO->getFastMathFlags();

SmallVector<const Value *> Args(CI->args());
IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
OpTypes, FMF,
dyn_cast<IntrinsicInst>(CI));
DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
} else {
assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
SmallVector<const Value *, 4> OpValues(I.operand_values());
DefaultCost = TTI.getArithmeticInstrCost(
I.getOpcode(), VectorTy, CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the arguments after CostKind actually needed? (They have defaults)

Op2Info, OpValues, &I);
}

InstructionCost VecLibCost =
TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
return VecLibCost > DefaultCost;
}

/// Returns true when successfully replaced \p I with a suitable function taking
/// vector arguments, based on available mappings in the \p TLI. Currently only
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
/// vector arguments, based on available mappings in the \p TLI and costs.
/// Currently only works when \p I is a call to vectorized intrinsic or the frem
/// instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
const TargetTransformInfo &TTI,
Instruction &I) {
// At the moment VFABI assumes the return type is always widened unless it is
// a void type.
auto *VTy = dyn_cast<VectorType>(I.getType());
ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
auto *VectorTy = dyn_cast<VectorType>(I.getType());
ElementCount EC(VectorTy ? VectorTy->getElementCount()
: ElementCount::getFixed(0));

// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
Expand All @@ -125,9 +171,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
if (EC.isZero())
if (EC.isZero()) {
EC = VectorArgTy->getElementCount();
else if (EC != VectorArgTy->getElementCount())
VectorTy = VectorArgTy;
} else if (EC != VectorArgTy->getElementCount())
return false;
} else
// Exit when it is supposed to be a vector argument but it isn't.
Expand All @@ -139,8 +186,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
assert(VTy && "Return type must be a vector");
auto *ScalarTy = VTy->getScalarType();
assert(VectorTy && "Return type must be a vector");
auto *ScalarTy = VectorTy->getScalarType();
LibFunc Func;
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
Expand Down Expand Up @@ -200,6 +247,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);

if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
return false;

replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
Expand All @@ -220,13 +270,14 @@ static bool isSupportedInstruction(Instruction *I) {
return false;
}

static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
static bool runImpl(const TargetLibraryInfo &TLI,
const TargetTransformInfo &TTI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (!isSupportedInstruction(&I))
continue;
if (replaceWithCallToVeclib(TLI, I)) {
if (replaceWithCallToVeclib(TLI, TTI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
}
Expand All @@ -244,14 +295,16 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
FunctionAnalysisManager &AM) {
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto Changed = runImpl(TLI, F);
const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
auto Changed = runImpl(TLI, TTI, F);
if (Changed) {
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
<< NumCallsReplaced << "\n");

PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<TargetLibraryAnalysis>();
PA.preserve<TargetIRAnalysis>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<LoopAccessAnalysis>();
PA.preserve<DemandedBitsAnalysis>();
Expand All @@ -269,13 +322,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
const TargetLibraryInfo &TLI =
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
return runImpl(TLI, F);
const TargetTransformInfo &TTI =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return runImpl(TLI, TTI, F);
}

void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
AU.addPreserved<TargetTransformInfoWrapperPass>();
AU.addPreserved<ScalarEvolutionWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @frem_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be a good idea to add a flag to ReplaceWithVeclib.cpp to override the cost for testing purposes. Then filter for call or frem with the autogenerator, and add a second runline using the flag to make sure we still perform the transformation.

; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {

define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: @frem_f64(
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
Expand Down