Skip to content

Commit 2a85ed1

Browse files
[TLI] ReplaceWithVecLib pass uses CostModel
The replace-with-veclib pass now replaces an instruction with a veclib call only when the cost of that call is not found to be higher than the cost of the original instruction.
1 parent 3c246ef commit 2a85ed1

File tree

3 files changed

+71
-14
lines changed

3 files changed

+71
-14
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// Replaces LLVM IR instructions with vector operands (i.e., the frem
1010
// instruction or calls to LLVM intrinsics) with matching calls to functions
1111
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
12+
// This happens only when the cost of calling the vector library is not found to
13+
// be more than the cost of the original instruction.
1214
//
1315
//===----------------------------------------------------------------------===//
1416

@@ -20,12 +22,16 @@
2022
#include "llvm/Analysis/GlobalsModRef.h"
2123
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
2224
#include "llvm/Analysis/TargetLibraryInfo.h"
25+
#include "llvm/Analysis/TargetTransformInfo.h"
2326
#include "llvm/Analysis/VectorUtils.h"
2427
#include "llvm/CodeGen/Passes.h"
2528
#include "llvm/IR/DerivedTypes.h"
2629
#include "llvm/IR/IRBuilder.h"
2730
#include "llvm/IR/InstIterator.h"
31+
#include "llvm/IR/Instructions.h"
32+
#include "llvm/IR/IntrinsicInst.h"
2833
#include "llvm/IR/VFABIDemangler.h"
34+
#include "llvm/Support/InstructionCost.h"
2935
#include "llvm/Support/TypeSize.h"
3036
#include "llvm/Transforms/Utils/ModuleUtils.h"
3137

@@ -96,15 +102,55 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
96102
Replacement->copyFastMathFlags(&I);
97103
}
98104

105+
/// Returns true when the cost that \p TTI reports for a call to the vector
106+
/// library function \p TLIFunc is strictly greater than the cost it reports
/// for the original instruction \p I.
///
/// Both costs are queried with the TCK_RecipThroughput cost kind, and
/// \p VectorTy is the type handed to every cost query. When \p CI is non-null
/// the original cost is the intrinsic cost for CI's intrinsic ID; otherwise
/// \p I must be an frem and its arithmetic cost is used.
///
/// NOTE(review): \p CI is presumably \p I viewed as a call (null for frem) --
/// confirm against the caller. \p TLI is not referenced in this body.
107+
static bool isVeclibCallSlower(const TargetLibraryInfo &TLI,
108+
const TargetTransformInfo &TTI, Instruction &I,
109+
VectorType *VectorTy, CallInst *CI,
110+
Function *TLIFunc) {
111+
// Operand types: the call arguments when CI is set, else the operands of I.
SmallVector<Type *, 4> OpTypes;
112+
for (auto &Op : CI ? CI->args() : I.operands())
113+
OpTypes.push_back(Op->getType());
114+
115+
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
116+
// Cost of leaving the original instruction in place.
InstructionCost DefaultCost;
117+
if (CI) {
118+
// Forward the call's fast-math flags (if it has any) to the cost query.
FastMathFlags FMF;
119+
if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
120+
FMF = FPMO->getFastMathFlags();
121+
122+
SmallVector<const Value *> Args(CI->args());
123+
IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
124+
OpTypes, FMF,
125+
dyn_cast<IntrinsicInst>(CI));
126+
DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
127+
} else {
128+
// Without a call, the only instruction handled here is frem.
assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
129+
auto Op2Info = TTI.getOperandInfo(I.getOperand(1));
130+
SmallVector<const Value *, 4> OpValues(I.operand_values());
131+
DefaultCost = TTI.getArithmeticInstrCost(
132+
I.getOpcode(), VectorTy, CostKind,
133+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
134+
Op2Info, OpValues, &I);
135+
}
136+
137+
// Cost of the replacement: a call to TLIFunc with the same operand types.
InstructionCost VecLibCost =
138+
TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
139+
// Strict comparison: an equal cost is not reported as slower.
return VecLibCost > DefaultCost;
140+
}
141+
99142
/// Returns true when successfully replaced \p I with a suitable function taking
100-
/// vector arguments, based on available mappings in the \p TLI. Currently only
101-
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
143+
/// vector arguments, based on available mappings in the \p TLI and costs.
144+
/// Currently only works when \p I is a call to vectorized intrinsic or the frem
145+
/// instruction.
102146
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
147+
const TargetTransformInfo &TTI,
103148
Instruction &I) {
104149
// At the moment VFABI assumes the return type is always widened unless it is
105150
// a void type.
106-
auto *VTy = dyn_cast<VectorType>(I.getType());
107-
ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
151+
auto *VectorTy = dyn_cast<VectorType>(I.getType());
152+
ElementCount EC(VectorTy ? VectorTy->getElementCount()
153+
: ElementCount::getFixed(0));
108154

109155
// Compute the argument types of the corresponding scalar call and the scalar
110156
// function name. For calls, it additionally finds the function to replace
@@ -125,9 +171,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
125171
ScalarArgTypes.push_back(VectorArgTy->getElementType());
126172
// When return type is void, set EC to the first vector argument, and
127173
// disallow vector arguments with different ECs.
128-
if (EC.isZero())
174+
if (EC.isZero()) {
129175
EC = VectorArgTy->getElementCount();
130-
else if (EC != VectorArgTy->getElementCount())
176+
VectorTy = VectorArgTy;
177+
} else if (EC != VectorArgTy->getElementCount())
131178
return false;
132179
} else
133180
// Exit when it is supposed to be a vector argument but it isn't.
@@ -139,8 +186,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
139186
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
140187
: Intrinsic::getName(IID).str();
141188
} else {
142-
assert(VTy && "Return type must be a vector");
143-
auto *ScalarTy = VTy->getScalarType();
189+
assert(VectorTy && "Return type must be a vector");
190+
auto *ScalarTy = VectorTy->getScalarType();
144191
LibFunc Func;
145192
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
146193
return false;
@@ -200,6 +247,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
200247
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
201248
VD->getVectorFnName(), FuncToReplace);
202249

250+
if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
251+
return false;
252+
203253
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
204254
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
205255
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,13 +270,14 @@ static bool isSupportedInstruction(Instruction *I) {
220270
return false;
221271
}
222272

223-
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
273+
static bool runImpl(const TargetLibraryInfo &TLI,
274+
const TargetTransformInfo &TTI, Function &F) {
224275
bool Changed = false;
225276
SmallVector<Instruction *> ReplacedCalls;
226277
for (auto &I : instructions(F)) {
227278
if (!isSupportedInstruction(&I))
228279
continue;
229-
if (replaceWithCallToVeclib(TLI, I)) {
280+
if (replaceWithCallToVeclib(TLI, TTI, I)) {
230281
ReplacedCalls.push_back(&I);
231282
Changed = true;
232283
}
@@ -244,14 +295,16 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
244295
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
245296
FunctionAnalysisManager &AM) {
246297
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
247-
auto Changed = runImpl(TLI, F);
298+
const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
299+
auto Changed = runImpl(TLI, TTI, F);
248300
if (Changed) {
249301
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
250302
<< NumCallsReplaced << "\n");
251303

252304
PreservedAnalyses PA;
253305
PA.preserveSet<CFGAnalyses>();
254306
PA.preserve<TargetLibraryAnalysis>();
307+
PA.preserve<TargetIRAnalysis>();
255308
PA.preserve<ScalarEvolutionAnalysis>();
256309
PA.preserve<LoopAccessAnalysis>();
257310
PA.preserve<DemandedBitsAnalysis>();
@@ -269,13 +322,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
269322
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
270323
const TargetLibraryInfo &TLI =
271324
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
272-
return runImpl(TLI, F);
325+
const TargetTransformInfo &TTI =
326+
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
327+
return runImpl(TLI, TTI, F);
273328
}
274329

275330
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
276331
AU.setPreservesCFG();
277332
AU.addRequired<TargetLibraryInfoWrapperPass>();
333+
AU.addRequired<TargetTransformInfoWrapperPass>();
278334
AU.addPreserved<TargetLibraryInfoWrapperPass>();
335+
AU.addPreserved<TargetTransformInfoWrapperPass>();
279336
AU.addPreserved<ScalarEvolutionWrapperPass>();
280337
AU.addPreserved<AAResultsWrapperPass>();
281338
AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();

llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
428428
define <2 x double> @frem_f64(<2 x double> %in) {
429429
; CHECK-LABEL: define <2 x double> @frem_f64
430430
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
431-
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
431+
; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
432432
; CHECK-NEXT: ret <2 x double> [[TMP1]]
433433
;
434434
%1= frem <2 x double> %in, %in

llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
386386

387387
define <2 x double> @frem_f64(<2 x double> %in) {
388388
; CHECK-LABEL: @frem_f64(
389-
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
389+
; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
390390
; CHECK-NEXT: ret <2 x double> [[TMP1]]
391391
;
392392
%1= frem <2 x double> %in, %in

0 commit comments

Comments
 (0)