@@ -43,24 +43,24 @@ STATISTIC(NumFuncUsedAdded,
43
43
" Number of functions added to `llvm.compiler.used`" );
44
44
45
45
// / Returns a vector Function that it adds to the Module \p M. When an \p
46
- // / OptOldFunc is given, it copies its attributes to the newly created Function.
46
+ // / ScalarFunc is not null, it copies its attributes to the newly created
47
+ // / Function.
47
48
Function *getTLIFunction (Module *M, FunctionType *VectorFTy,
48
- std::optional<Function *> OptOldFunc,
49
- const StringRef TLIName) {
49
+ Function *ScalarFunc, const StringRef TLIName) {
50
50
Function *TLIFunc = M->getFunction (TLIName);
51
51
if (!TLIFunc) {
52
52
TLIFunc =
53
53
Function::Create (VectorFTy, Function::ExternalLinkage, TLIName, *M);
54
- if (OptOldFunc )
55
- TLIFunc->copyAttributesFrom (*OptOldFunc );
54
+ if (ScalarFunc )
55
+ TLIFunc->copyAttributesFrom (ScalarFunc );
56
56
57
57
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Added vector library function `"
58
58
<< TLIName << " ` of type `" << *(TLIFunc->getType ())
59
59
<< " ` to module.\n " );
60
60
61
61
++NumTLIFuncDeclAdded;
62
62
// Add the freshly created function to llvm.compiler.used, similar to as it
63
- // is done in InjectTLIMappings
63
+ // is done in InjectTLIMappings.
64
64
appendToCompilerUsed (*M, {TLIFunc});
65
65
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Adding `" << TLIName
66
66
<< " ` to `@llvm.compiler.used`.\n " );
@@ -72,11 +72,11 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
72
72
// / Replace the call to the vector intrinsic ( \p FuncToReplace ) with a call to
73
73
// / the corresponding function from the vector library ( \p TLIFunc ).
74
74
static void replaceWithTLIFunction (CallInst &CI, VFInfo &Info,
75
- Function *TLIFunc, FunctionType *VecFTy ) {
75
+ Function *TLIVecFunc ) {
76
76
IRBuilder<> IRBuilder (&CI);
77
77
SmallVector<Value *> Args (CI.args ());
78
78
if (auto OptMaskpos = Info.getParamIndexForOptionalMask ()) {
79
- if (Args.size () == VecFTy ->getNumParams ())
79
+ if (Args.size () == TLIVecFunc-> getFunctionType () ->getNumParams ())
80
80
static_assert (true && " mask was already in place" );
81
81
82
82
auto *MaskTy =
@@ -88,9 +88,7 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
88
88
// Preserve the operand bundles.
89
89
SmallVector<OperandBundleDef, 1 > OpBundles;
90
90
CI.getOperandBundlesAsDefs (OpBundles);
91
- CallInst *Replacement = IRBuilder.CreateCall (TLIFunc, Args, OpBundles);
92
- assert (VecFTy == TLIFunc->getFunctionType () &&
93
- " Expecting function types to be identical" );
91
+ CallInst *Replacement = IRBuilder.CreateCall (TLIVecFunc, Args, OpBundles);
94
92
CI.replaceAllUsesWith (Replacement);
95
93
// Preserve fast math flags for FP math.
96
94
if (isa<FPMathOperator>(Replacement))
@@ -102,10 +100,10 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
102
100
static std::optional<const VecDesc *> getVecDesc (const TargetLibraryInfo &TLI,
103
101
const StringRef &ScalarName,
104
102
const ElementCount &VF) {
105
- if (auto *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true ))
106
- return VDMasked;
107
103
if (auto *VDNoMask = TLI.getVectorMappingInfo (ScalarName, VF, false ))
108
104
return VDNoMask;
105
+ if (auto *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true ))
106
+ return VDMasked;
109
107
return std::nullopt;
110
108
}
111
109
@@ -117,20 +115,20 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
117
115
return false ;
118
116
119
117
auto IntrinsicID = CI.getCalledFunction ()->getIntrinsicID ();
120
- // Replacement is only performed for intrinsic functions
118
+ // Replacement is only performed for intrinsic functions.
121
119
if (IntrinsicID == Intrinsic::not_intrinsic)
122
120
return false ;
123
121
124
122
// Convert vector arguments to scalar type and check that all vector operands
125
123
// have identical vector width.
126
124
ElementCount VF = ElementCount::getFixed (0 );
127
- SmallVector<Type *> ScalarTypes ;
125
+ SmallVector<Type *> ScalarArgTypes ;
128
126
for (auto Arg : enumerate(CI.args ())) {
129
127
auto *ArgTy = Arg.value ()->getType ();
130
128
if (isVectorIntrinsicWithScalarOpAtArg (IntrinsicID, Arg.index ())) {
131
- ScalarTypes .push_back (ArgTy);
129
+ ScalarArgTypes .push_back (ArgTy);
132
130
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
133
- ScalarTypes .push_back (ArgTy->getScalarType ());
131
+ ScalarArgTypes .push_back (ArgTy->getScalarType ());
134
132
// Disallow vector arguments with different VFs. When processing the first
135
133
// vector argument, store it's VF, and for the rest ensure that they match
136
134
// it.
@@ -139,15 +137,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
139
137
else if (VF != VectorArgTy->getElementCount ())
140
138
return false ;
141
139
} else
142
- // enters when it is supposed to be a vector argument but it isn't.
140
+ // Exit when it is supposed to be a vector argument but it isn't.
143
141
return false ;
144
142
}
145
143
146
144
// Try to reconstruct the name for the scalar version of this intrinsic using
147
145
// the intrinsic ID and the argument types converted to scalar above.
148
146
std::string ScalarName =
149
147
(Intrinsic::isOverloaded (IntrinsicID)
150
- ? Intrinsic::getName (IntrinsicID, ScalarTypes , CI.getModule ())
148
+ ? Intrinsic::getName (IntrinsicID, ScalarArgTypes , CI.getModule ())
151
149
: Intrinsic::getName (IntrinsicID).str ());
152
150
153
151
// The TargetLibraryInfo does not contain a vectorized version of the scalar
@@ -169,7 +167,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
169
167
// Replace the call to the intrinsic with a call to the vector library
170
168
// function.
171
169
Type *ScalarRetTy = CI.getType ()->getScalarType ();
172
- FunctionType *ScalarFTy = FunctionType::get (ScalarRetTy, ScalarTypes, false );
170
+ FunctionType *ScalarFTy =
171
+ FunctionType::get (ScalarRetTy, ScalarArgTypes, /* isVarArg*/ false );
173
172
const std::string MangledName = VD->getVectorFunctionABIVariantString ();
174
173
auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
175
174
if (!OptInfo)
@@ -182,7 +181,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
182
181
Function *FuncToReplace = CI.getCalledFunction ();
183
182
Function *TLIFunc = getTLIFunction (CI.getModule (), VectorFTy, FuncToReplace,
184
183
VD->getVectorFnName ());
185
- replaceWithTLIFunction (CI, *OptInfo, TLIFunc, VectorFTy );
184
+ replaceWithTLIFunction (CI, *OptInfo, TLIFunc);
186
185
187
186
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
188
187
<< FuncToReplace->getName () << " ` with call to `"
0 commit comments