9
9
// Replaces LLVM IR instructions with vector operands (i.e., the frem
10
10
// instruction or calls to LLVM intrinsics) with matching calls to functions
11
11
// from a vector library (e.g., libmvec, SVML) using the TargetLibraryInfo interface.
12
+ // This happens only when the cost of calling the vector library does not
13
+ // exceed the cost of the original instruction.
12
14
//
13
15
// ===----------------------------------------------------------------------===//
14
16
20
22
#include " llvm/Analysis/GlobalsModRef.h"
21
23
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
22
24
#include " llvm/Analysis/TargetLibraryInfo.h"
25
+ #include " llvm/Analysis/TargetTransformInfo.h"
23
26
#include " llvm/Analysis/VectorUtils.h"
24
27
#include " llvm/CodeGen/Passes.h"
25
28
#include " llvm/IR/DerivedTypes.h"
26
29
#include " llvm/IR/IRBuilder.h"
27
30
#include " llvm/IR/InstIterator.h"
31
+ #include " llvm/IR/Instructions.h"
32
+ #include " llvm/IR/IntrinsicInst.h"
28
33
#include " llvm/IR/VFABIDemangler.h"
34
+ #include " llvm/Support/InstructionCost.h"
29
35
#include " llvm/Support/TypeSize.h"
30
36
#include " llvm/Transforms/Utils/ModuleUtils.h"
31
37
@@ -96,15 +102,55 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
96
102
Replacement->copyFastMathFlags (&I);
97
103
}
98
104
105
+ /// Returns whether the vector library call \p TLIFunc costs more than the
106
+ /// original instruction \p I.
///
/// Every cost is queried from \p TTI with the TCK_RecipThroughput cost kind.
/// The result is true only when the library call's cost compares strictly
/// greater than the original instruction's cost; equal costs yield false.
///
/// \param TLI      Not referenced anywhere in the body.
/// \param TTI      Provides all three cost queries made here.
/// \param I        Instruction considered for replacement. When \p CI is null
///                 it is asserted to be an frem.
/// \param VectorTy Type handed to each of the cost queries.
///                 NOTE(review): the caller reassigns its VectorTy to a vector
///                 argument's type when the element count starts out as zero,
///                 so this is not always I's own type - confirm the cost
///                 queries are meant to receive that.
/// \param CI       When non-null, the original cost is computed as an
///                 intrinsic cost using CI's intrinsic ID, arguments and
///                 fast-math flags (presumably \p I viewed as a call - verify
///                 against the caller). When null, the arithmetic-cost path
///                 is taken.
/// \param TLIFunc  Vector library function whose call cost is compared
///                 against the original instruction's cost.
107
+ static bool isVeclibCallSlower (const TargetLibraryInfo &TLI,
108
+ const TargetTransformInfo &TTI, Instruction &I,
109
+ VectorType *VectorTy, CallInst *CI,
110
+ Function *TLIFunc) {
111
+ SmallVector<Type *, 4 > OpTypes;
112
// Collect operand types: the call's arguments when CI is set, otherwise all
// of I's operands.
+ for (auto &Op : CI ? CI->args () : I.operands ())
113
+ OpTypes.push_back (Op->getType ());
114
+
115
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
116
+ InstructionCost DefaultCost;
117
// Cost of keeping I as it is: an intrinsic cost for calls, an arithmetic
// cost for frem.
+ if (CI) {
118
+ FastMathFlags FMF;
119
// Forward the call's fast-math flags, when it carries any, to the cost query.
+ if (auto *FPMO = dyn_cast<FPMathOperator>(CI))
120
+ FMF = FPMO->getFastMathFlags ();
121
+
122
+ SmallVector<const Value *> Args (CI->args ());
123
+ IntrinsicCostAttributes CostAttrs (CI->getIntrinsicID (), VectorTy, Args,
124
+ OpTypes, FMF,
125
+ dyn_cast<IntrinsicInst>(CI));
126
+ DefaultCost = TTI.getIntrinsicInstrCost (CostAttrs, CostKind);
127
+ } else {
128
+ assert ((I.getOpcode () == Instruction::FRem) && " Only FRem is supported" );
129
// Only the second operand's info is queried from TTI; the operand-info
// argument passed just before it is a fixed {OK_AnyValue, OP_None}.
+ auto Op2Info = TTI.getOperandInfo (I.getOperand (1 ));
130
+ SmallVector<const Value *, 4 > OpValues (I.operand_values ());
131
+ DefaultCost = TTI.getArithmeticInstrCost (
132
+ I.getOpcode (), VectorTy, CostKind,
133
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
134
+ Op2Info, OpValues, &I);
135
+ }
136
+
137
// Cost of the replacement call to TLIFunc, using the same operand types and
// cost kind as above.
+ InstructionCost VecLibCost =
138
+ TTI.getCallInstrCost (TLIFunc, VectorTy, OpTypes, CostKind);
139
// Strict comparison: a tie is not treated as slower.
+ return VecLibCost > DefaultCost;
140
+ }
141
+
99
142
// / Returns true when successfully replaced \p I with a suitable function taking
100
- // / vector arguments, based on available mappings in the \p TLI. Currently only
101
- // / works when \p I is a call to vectorized intrinsic or the frem instruction.
143
+ // / vector arguments, based on available mappings in the \p TLI and costs.
144
+ // / Currently only works when \p I is a call to vectorized intrinsic or the frem
145
+ // / instruction.
102
146
static bool replaceWithCallToVeclib (const TargetLibraryInfo &TLI,
147
+ const TargetTransformInfo &TTI,
103
148
Instruction &I) {
104
149
// At the moment VFABI assumes the return type is always widened unless it is
105
150
// a void type.
106
- auto *VTy = dyn_cast<VectorType>(I.getType ());
107
- ElementCount EC (VTy ? VTy->getElementCount () : ElementCount::getFixed (0 ));
151
+ auto *VectorTy = dyn_cast<VectorType>(I.getType ());
152
+ ElementCount EC (VectorTy ? VectorTy->getElementCount ()
153
+ : ElementCount::getFixed (0 ));
108
154
109
155
// Compute the argument types of the corresponding scalar call and the scalar
110
156
// function name. For calls, it additionally finds the function to replace
@@ -125,9 +171,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
125
171
ScalarArgTypes.push_back (VectorArgTy->getElementType ());
126
172
// When return type is void, set EC to the first vector argument, and
127
173
// disallow vector arguments with different ECs.
128
- if (EC.isZero ())
174
+ if (EC.isZero ()) {
129
175
EC = VectorArgTy->getElementCount ();
130
- else if (EC != VectorArgTy->getElementCount ())
176
+ VectorTy = VectorArgTy;
177
+ } else if (EC != VectorArgTy->getElementCount ())
131
178
return false ;
132
179
} else
133
180
// Exit when it is supposed to be a vector argument but it isn't.
@@ -139,8 +186,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
139
186
? Intrinsic::getName (IID, ScalarArgTypes, I.getModule ())
140
187
: Intrinsic::getName (IID).str ();
141
188
} else {
142
- assert (VTy && " Return type must be a vector" );
143
- auto *ScalarTy = VTy ->getScalarType ();
189
+ assert (VectorTy && " Return type must be a vector" );
190
+ auto *ScalarTy = VectorTy ->getScalarType ();
144
191
LibFunc Func;
145
192
if (!TLI.getLibFunc (I.getOpcode (), ScalarTy, Func))
146
193
return false ;
@@ -200,6 +247,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
200
247
Function *TLIFunc = getTLIFunction (I.getModule (), VectorFTy,
201
248
VD->getVectorFnName (), FuncToReplace);
202
249
250
+ if (isVeclibCallSlower (TLI, TTI, I, VectorTy, CI, TLIFunc))
251
+ return false ;
252
+
203
253
replaceWithTLIFunction (I, *OptInfo, TLIFunc);
204
254
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `" << ScalarName
205
255
<< " ` with call to `" << TLIFunc->getName () << " `.\n " );
@@ -220,13 +270,14 @@ static bool isSupportedInstruction(Instruction *I) {
220
270
return false ;
221
271
}
222
272
223
- static bool runImpl (const TargetLibraryInfo &TLI, Function &F) {
273
+ static bool runImpl (const TargetLibraryInfo &TLI,
274
+ const TargetTransformInfo &TTI, Function &F) {
224
275
bool Changed = false ;
225
276
SmallVector<Instruction *> ReplacedCalls;
226
277
for (auto &I : instructions (F)) {
227
278
if (!isSupportedInstruction (&I))
228
279
continue ;
229
- if (replaceWithCallToVeclib (TLI, I)) {
280
+ if (replaceWithCallToVeclib (TLI, TTI, I)) {
230
281
ReplacedCalls.push_back (&I);
231
282
Changed = true ;
232
283
}
@@ -244,14 +295,16 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
244
295
PreservedAnalyses ReplaceWithVeclib::run (Function &F,
245
296
FunctionAnalysisManager &AM) {
246
297
const TargetLibraryInfo &TLI = AM.getResult <TargetLibraryAnalysis>(F);
247
- auto Changed = runImpl (TLI, F);
298
+ const TargetTransformInfo &TTI = AM.getResult <TargetIRAnalysis>(F);
299
+ auto Changed = runImpl (TLI, TTI, F);
248
300
if (Changed) {
249
301
LLVM_DEBUG (dbgs () << " Instructions replaced with vector libraries: "
250
302
<< NumCallsReplaced << " \n " );
251
303
252
304
PreservedAnalyses PA;
253
305
PA.preserveSet <CFGAnalyses>();
254
306
PA.preserve <TargetLibraryAnalysis>();
307
+ PA.preserve <TargetIRAnalysis>();
255
308
PA.preserve <ScalarEvolutionAnalysis>();
256
309
PA.preserve <LoopAccessAnalysis>();
257
310
PA.preserve <DemandedBitsAnalysis>();
@@ -269,13 +322,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
269
322
// Entry point of the ReplaceWithVeclibLegacy pass for a single function.
// Looks up the TargetLibraryInfo and TargetTransformInfo for F through their
// wrapper passes and forwards both, together with F, to runImpl; the value
// runImpl returns is returned unchanged. The line marked `-` below is the
// earlier form of the call, which passed only TLI and F.
bool ReplaceWithVeclibLegacy::runOnFunction (Function &F) {
270
323
const TargetLibraryInfo &TLI =
271
324
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI (F);
272
- return runImpl (TLI, F);
325
// Cost information for F, handed to runImpl alongside TLI.
+ const TargetTransformInfo &TTI =
326
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
327
+ return runImpl (TLI, TTI, F);
273
328
}
274
329
275
330
void ReplaceWithVeclibLegacy::getAnalysisUsage (AnalysisUsage &AU) const {
276
331
AU.setPreservesCFG ();
277
332
AU.addRequired <TargetLibraryInfoWrapperPass>();
333
+ AU.addRequired <TargetTransformInfoWrapperPass>();
278
334
AU.addPreserved <TargetLibraryInfoWrapperPass>();
335
+ AU.addPreserved <TargetTransformInfoWrapperPass>();
279
336
AU.addPreserved <ScalarEvolutionWrapperPass>();
280
337
AU.addPreserved <AAResultsWrapperPass>();
281
338
AU.addPreserved <OptimizationRemarkEmitterWrapperPass>();
0 commit comments