15
15
#include " llvm/CodeGen/ReplaceWithVeclib.h"
16
16
#include " llvm/ADT/STLExtras.h"
17
17
#include " llvm/ADT/Statistic.h"
18
+ #include " llvm/ADT/StringRef.h"
18
19
#include " llvm/Analysis/DemandedBits.h"
19
20
#include " llvm/Analysis/GlobalsModRef.h"
20
21
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
21
22
#include " llvm/Analysis/TargetLibraryInfo.h"
22
23
#include " llvm/Analysis/VectorUtils.h"
23
24
#include " llvm/CodeGen/Passes.h"
25
+ #include " llvm/IR/DerivedTypes.h"
24
26
#include " llvm/IR/IRBuilder.h"
25
27
#include " llvm/IR/InstIterator.h"
28
+ #include " llvm/Support/TypeSize.h"
26
29
#include " llvm/Transforms/Utils/ModuleUtils.h"
30
+ #include < optional>
27
31
28
32
using namespace llvm ;
29
33
@@ -38,138 +42,166 @@ STATISTIC(NumTLIFuncDeclAdded,
38
42
STATISTIC (NumFuncUsedAdded,
39
43
" Number of functions added to `llvm.compiler.used`" );
40
44
41
- static bool replaceWithTLIFunction (CallInst &CI, const StringRef TLIName) {
42
- Module *M = CI.getModule ();
43
-
44
- Function *OldFunc = CI.getCalledFunction ();
45
-
46
- // Check if the vector library function is already declared in this module,
47
- // otherwise insert it.
45
+ // / Returns a vector Function that it adds to the Module \p M. When an \p
46
+ // / OptOldFunc is given, it copies its attributes to the newly created Function.
47
+ Function *getTLIFunction (Module *M, FunctionType *VectorFTy,
48
+ std::optional<Function *> OptOldFunc,
49
+ const StringRef TLIName) {
48
50
Function *TLIFunc = M->getFunction (TLIName);
49
51
if (!TLIFunc) {
50
- TLIFunc = Function::Create (OldFunc->getFunctionType (),
51
- Function::ExternalLinkage, TLIName, *M);
52
- TLIFunc->copyAttributesFrom (OldFunc);
52
+ TLIFunc =
53
+ Function::Create (VectorFTy, Function::ExternalLinkage, TLIName, *M);
54
+ if (OptOldFunc)
55
+ TLIFunc->copyAttributesFrom (*OptOldFunc);
53
56
54
57
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Added vector library function `"
55
58
<< TLIName << " ` of type `" << *(TLIFunc->getType ())
56
59
<< " ` to module.\n " );
57
60
58
61
++NumTLIFuncDeclAdded;
59
-
60
- // Add the freshly created function to llvm.compiler.used,
61
- // similar to as it is done in InjectTLIMappings
62
+ // Add the freshly created function to llvm.compiler.used, similar to as it
63
+ // is done in InjectTLIMappings
62
64
appendToCompilerUsed (*M, {TLIFunc});
63
-
64
65
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Adding `" << TLIName
65
66
<< " ` to `@llvm.compiler.used`.\n " );
66
67
++NumFuncUsedAdded;
67
68
}
69
+ return TLIFunc;
70
+ }
68
71
69
- // Replace the call to the vector intrinsic with a call
70
- // to the corresponding function from the vector library.
72
+ // / Replace the call to the vector intrinsic ( \p OldFunc ) with a call to the
73
+ // / corresponding function from the vector library ( \p TLIFunc ).
74
+ static bool replaceWithTLIFunction (const Module *M, CallInst &CI,
75
+ const ElementCount &VecVF, Function *OldFunc,
76
+ Function *TLIFunc, FunctionType *VecFTy,
77
+ bool IsMasked) {
71
78
IRBuilder<> IRBuilder (&CI);
72
79
SmallVector<Value *> Args (CI.args ());
80
+ if (IsMasked) {
81
+ if (Args.size () == VecFTy->getNumParams ())
82
+ static_assert (true && " mask was already in place" );
83
+
84
+ auto *MaskTy = VectorType::get (Type::getInt1Ty (M->getContext ()), VecVF);
85
+ Args.push_back (Constant::getAllOnesValue (MaskTy));
86
+ }
87
+
73
88
// Preserve the operand bundles.
74
89
SmallVector<OperandBundleDef, 1 > OpBundles;
75
90
CI.getOperandBundlesAsDefs (OpBundles);
76
91
CallInst *Replacement = IRBuilder.CreateCall (TLIFunc, Args, OpBundles);
77
- assert (OldFunc-> getFunctionType () == TLIFunc->getFunctionType () &&
92
+ assert (VecFTy == TLIFunc->getFunctionType () &&
78
93
" Expecting function types to be identical" );
79
94
CI.replaceAllUsesWith (Replacement);
80
- if (isa<FPMathOperator>(Replacement)) {
81
- // Preserve fast math flags for FP math.
95
+ // Preserve fast math flags for FP math.
96
+ if (isa<FPMathOperator>(Replacement))
82
97
Replacement->copyFastMathFlags (&CI);
83
- }
84
98
85
99
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
86
- << OldFunc->getName () << " ` with call to `" << TLIName
87
- << " `.\n " );
100
+ << OldFunc->getName () << " ` with call to `"
101
+ << TLIFunc-> getName () << " `.\n " );
88
102
++NumCallsReplaced;
89
103
return true ;
90
104
}
91
105
106
+ // / Utility method to get the VecDesc, depending on whether there is a TLI
107
+ // / mapping, either with or without a mask.
108
+ static std::optional<const VecDesc *> getVecDesc (const TargetLibraryInfo &TLI,
109
+ const StringRef &ScalarName,
110
+ const ElementCount &VF) {
111
+ const VecDesc *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true );
112
+ const VecDesc *VDNoMask = TLI.getVectorMappingInfo (ScalarName, VF, false );
113
+ // Invalid when there are both variants (ie masked and unmasked), or none
114
+ if ((VDMasked == nullptr ) == (VDNoMask == nullptr ))
115
+ return std::nullopt;
116
+
117
+ return {VDMasked != nullptr ? VDMasked : VDNoMask};
118
+ }
119
+
120
+ // / Returns whether it is able to replace a call to the intrinsic \p CI with a
121
+ // / TLI mapped call.
92
122
static bool replaceWithCallToVeclib (const TargetLibraryInfo &TLI,
93
123
CallInst &CI) {
94
- if (!CI.getCalledFunction ()) {
124
+ if (!CI.getCalledFunction ())
95
125
return false ;
96
- }
97
126
98
127
auto IntrinsicID = CI.getCalledFunction ()->getIntrinsicID ();
99
- if (IntrinsicID == Intrinsic::not_intrinsic) {
100
- // Replacement is only performed for intrinsic functions
128
+ // Replacement is only performed for intrinsic functions
129
+ if (IntrinsicID == Intrinsic::not_intrinsic)
101
130
return false ;
102
- }
103
131
104
- // Convert vector arguments to scalar type and check that
105
- // all vector operands have identical vector width.
132
+ // Convert vector arguments to scalar type and check that all vector operands
133
+ // have identical vector width.
106
134
ElementCount VF = ElementCount::getFixed (0 );
107
135
SmallVector<Type *> ScalarTypes;
108
- bool MayBeMasked = false ;
109
136
for (auto Arg : enumerate(CI.args ())) {
110
- auto *ArgType = Arg.value ()->getType ();
111
- // Vector calls to intrinsics can still have
112
- // scalar operands for specific arguments.
137
+ auto *ArgTy = Arg.value ()->getType ();
113
138
if (isVectorIntrinsicWithScalarOpAtArg (IntrinsicID, Arg.index ())) {
114
- ScalarTypes.push_back (ArgType);
115
- } else {
116
- // The argument in this place should be a vector if
117
- // this is a call to a vector intrinsic.
118
- auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
119
- if (!VectorArgTy) {
120
- // The argument is not a vector, do not perform
121
- // the replacement.
122
- return false ;
123
- }
124
- ElementCount NumElements = VectorArgTy->getElementCount ();
125
- if (NumElements.isScalable ())
126
- MayBeMasked = true ;
127
-
128
- // The different arguments differ in vector size.
129
- if (VF.isNonZero () && VF != NumElements)
139
+ ScalarTypes.push_back (ArgTy);
140
+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
141
+ ScalarTypes.push_back (ArgTy->getScalarType ());
142
+ // Disallow vector arguments with different VFs. When processing the first
143
+ // vector argument, store it's VF, and for the rest ensure that they match
144
+ // it.
145
+ if (VF.isZero ())
146
+ VF = VectorArgTy->getElementCount ();
147
+ else if (VF != VectorArgTy->getElementCount ())
130
148
return false ;
131
- VF = NumElements;
132
- ScalarTypes.push_back (VectorArgTy->getElementType ());
149
+ } else {
150
+ // enters when it is supposed to be a vector argument but it isn't.
151
+ return false ;
133
152
}
134
153
}
135
154
136
- // Try to reconstruct the name for the scalar version of this
137
- // intrinsic using the intrinsic ID and the argument types
138
- // converted to scalar above.
139
- std::string ScalarName;
140
- if (Intrinsic::isOverloaded (IntrinsicID)) {
141
- ScalarName = Intrinsic::getName (IntrinsicID, ScalarTypes, CI.getModule ());
142
- } else {
143
- ScalarName = Intrinsic::getName (IntrinsicID).str ();
144
- }
155
+ // Try to reconstruct the name for the scalar version of this intrinsic using
156
+ // the intrinsic ID and the argument types converted to scalar above.
157
+ std::string ScalarName =
158
+ (Intrinsic::isOverloaded (IntrinsicID)
159
+ ? Intrinsic::getName (IntrinsicID, ScalarTypes, CI.getModule ())
160
+ : Intrinsic::getName (IntrinsicID).str ());
145
161
146
- if (!TLI. isFunctionVectorizable (ScalarName)) {
147
- // The TargetLibraryInfo does not contain a vectorized version of
148
- // the scalar function.
162
+ // The TargetLibraryInfo does not contain a vectorized version of the scalar
163
+ // function.
164
+ if (!TLI. isFunctionVectorizable (ScalarName))
149
165
return false ;
150
- }
151
166
152
- // Assume it has a mask when that is a possibility and has no mapping for
153
- // a Non-Masked variant.
154
- const bool IsMasked =
155
- MayBeMasked && !TLI. getVectorMappingInfo (ScalarName, VF, false );
156
- // Try to find the mapping for the scalar version of this intrinsic
157
- // and the exact vector width of the call operands in the
158
- // TargetLibraryInfo.
159
- StringRef TLIName = TLI.getVectorizedFunction (ScalarName, VF, IsMasked );
167
+ auto OptVD = getVecDesc (TLI, ScalarName, VF);
168
+ if (!OptVD)
169
+ return false ;
170
+
171
+ const VecDesc *VD = *OptVD;
172
+ // Try to find the mapping for the scalar version of this intrinsic and the
173
+ // exact vector width of the call operands in the TargetLibraryInfo.
174
+ StringRef TLIName = TLI.getVectorizedFunction (ScalarName, VF, VD-> isMasked () );
160
175
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Looking up TLI mapping for `"
161
176
<< ScalarName << " ` and vector width " << VF << " .\n " );
162
177
163
- if (!TLIName.empty ()) {
164
- // Found the correct mapping in the TargetLibraryInfo,
165
- // replace the call to the intrinsic with a call to
166
- // the vector library function.
167
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI function `" << TLIName
168
- << " `.\n " );
169
- return replaceWithTLIFunction (CI, TLIName);
170
- }
178
+ // TLI failed to find a correct mapping.
179
+ if (TLIName.empty ())
180
+ return false ;
171
181
172
- return false ;
182
+ // Find the vector Function and replace the call to the intrinsic with a call
183
+ // to the vector library function.
184
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI function `" << TLIName
185
+ << " `.\n " );
186
+
187
+ Type *ScalarRetTy = CI.getType ()->getScalarType ();
188
+ FunctionType *ScalarFTy = FunctionType::get (ScalarRetTy, ScalarTypes, false );
189
+ const std::string MangledName = VD->getVectorFunctionABIVariantString ();
190
+ auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
191
+ if (!OptInfo)
192
+ return false ;
193
+
194
+ // get the vector FunctionType
195
+ Module *M = CI.getModule ();
196
+ auto OptFTy = VFABI::createFunctionType (*OptInfo, ScalarFTy);
197
+ if (!OptFTy)
198
+ return false ;
199
+
200
+ Function *OldFunc = CI.getCalledFunction ();
201
+ FunctionType *VectorFTy = *OptFTy;
202
+ Function *TLIFunc = getTLIFunction (M, VectorFTy, OldFunc, TLIName);
203
+ return replaceWithTLIFunction (M, CI, OptInfo->Shape .VF , OldFunc, TLIFunc,
204
+ VectorFTy, VD->isMasked ());
173
205
}
174
206
175
207
static bool runImpl (const TargetLibraryInfo &TLI, Function &F) {
@@ -185,9 +217,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
185
217
}
186
218
// Erase the calls to the intrinsics that have been replaced
187
219
// with calls to the vector library.
188
- for (auto *CI : ReplacedCalls) {
220
+ for (auto *CI : ReplacedCalls)
189
221
CI->eraseFromParent ();
190
- }
191
222
return Changed;
192
223
}
193
224
@@ -207,10 +238,10 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
207
238
PA.preserve <DemandedBitsAnalysis>();
208
239
PA.preserve <OptimizationRemarkEmitterAnalysis>();
209
240
return PA;
210
- } else {
211
- // The pass did not replace any calls, hence it preserves all analyses.
212
- return PreservedAnalyses::all ();
213
241
}
242
+
243
+ // The pass did not replace any calls, hence it preserves all analyses.
244
+ return PreservedAnalyses::all ();
214
245
}
215
246
216
247
// //////////////////////////////////////////////////////////////////////////////
0 commit comments