@@ -69,20 +69,20 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
69
69
return TLIFunc;
70
70
}
71
71
72
- // / Replace the call to the vector intrinsic ( \p OldFunc ) with a call to the
73
- // / corresponding function from the vector library ( \p TLIFunc ).
74
- static bool replaceWithTLIFunction (const Module *M, CallInst &CI,
75
- const ElementCount &VecVF, Function *OldFunc,
76
- Function *TLIFunc, FunctionType *VecFTy,
77
- bool IsMasked) {
72
+ // / Replace the call to the vector intrinsic ( \p FuncToReplace ) with a call to
73
+ // / the corresponding function from the vector library ( \p TLIFunc ).
74
+ static void replaceWithTLIFunction (CallInst &CI, VFInfo &Info,
75
+ Function *TLIFunc, FunctionType *VecFTy) {
78
76
IRBuilder<> IRBuilder (&CI);
79
77
SmallVector<Value *> Args (CI.args ());
80
- if (IsMasked ) {
78
+ if (auto OptMaskpos = Info. getParamIndexForOptionalMask () ) {
81
79
if (Args.size () == VecFTy->getNumParams ())
82
80
static_assert (true && " mask was already in place" );
83
81
84
- auto *MaskTy = VectorType::get (Type::getInt1Ty (M->getContext ()), VecVF);
85
- Args.push_back (Constant::getAllOnesValue (MaskTy));
82
+ auto *MaskTy =
83
+ VectorType::get (Type::getInt1Ty (CI.getContext ()), Info.Shape .VF );
84
+ Args.insert (Args.begin () + OptMaskpos.value (),
85
+ Constant::getAllOnesValue (MaskTy));
86
86
}
87
87
88
88
// Preserve the operand bundles.
@@ -95,26 +95,18 @@ static bool replaceWithTLIFunction(const Module *M, CallInst &CI,
95
95
// Preserve fast math flags for FP math.
96
96
if (isa<FPMathOperator>(Replacement))
97
97
Replacement->copyFastMathFlags (&CI);
98
-
99
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
100
- << OldFunc->getName () << " ` with call to `"
101
- << TLIFunc->getName () << " `.\n " );
102
- ++NumCallsReplaced;
103
- return true ;
104
98
}
105
99
106
- // / Utility method to get the VecDesc, depending on whether there is a TLI
107
- // / mapping, either with or without a mask .
100
+ // / Utility method to get the VecDesc, depending on whether there is such a TLI
101
+ // / mapping, prioritizing a masked version .
108
102
static std::optional<const VecDesc *> getVecDesc (const TargetLibraryInfo &TLI,
109
103
const StringRef &ScalarName,
110
104
const ElementCount &VF) {
111
- const VecDesc *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true );
112
- const VecDesc *VDNoMask = TLI.getVectorMappingInfo (ScalarName, VF, false );
113
- // Invalid when there are both variants (ie masked and unmasked), or none
114
- if ((VDMasked == nullptr ) == (VDNoMask == nullptr ))
115
- return std::nullopt;
116
-
117
- return {VDMasked != nullptr ? VDMasked : VDNoMask};
105
+ if (auto *VDMasked = TLI.getVectorMappingInfo (ScalarName, VF, true ))
106
+ return VDMasked;
107
+ if (auto *VDNoMask = TLI.getVectorMappingInfo (ScalarName, VF, false ))
108
+ return VDNoMask;
109
+ return std::nullopt;
118
110
}
119
111
120
112
// / Returns whether it is able to replace a call to the intrinsic \p CI with a
@@ -146,10 +138,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
146
138
VF = VectorArgTy->getElementCount ();
147
139
else if (VF != VectorArgTy->getElementCount ())
148
140
return false ;
149
- } else {
141
+ } else
150
142
// enters when it is supposed to be a vector argument but it isn't.
151
143
return false ;
152
- }
153
144
}
154
145
155
146
// Try to reconstruct the name for the scalar version of this intrinsic using
@@ -164,44 +155,40 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
164
155
if (!TLI.isFunctionVectorizable (ScalarName))
165
156
return false ;
166
157
158
+ // Try to find the mapping for the scalar version of this intrinsic and the
159
+ // exact vector width of the call operands in the TargetLibraryInfo.
167
160
auto OptVD = getVecDesc (TLI, ScalarName, VF);
168
161
if (!OptVD)
169
162
return false ;
170
163
171
164
const VecDesc *VD = *OptVD;
172
- // Try to find the mapping for the scalar version of this intrinsic and the
173
- // exact vector width of the call operands in the TargetLibraryInfo.
174
- StringRef TLIName = TLI.getVectorizedFunction (ScalarName, VF, VD->isMasked ());
175
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Looking up TLI mapping for `"
176
- << ScalarName << " ` and vector width " << VF << " .\n " );
177
-
178
- // TLI failed to find a correct mapping.
179
- if (TLIName.empty ())
180
- return false ;
181
-
182
- // Find the vector Function and replace the call to the intrinsic with a call
183
- // to the vector library function.
184
- LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI function `" << TLIName
185
- << " `.\n " );
165
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Found TLI mapping from: `" << ScalarName
166
+ << " ` and vector width " << VF << " to: `"
167
+ << VD->getVectorFnName () << " `.\n " );
186
168
169
+ // Replace the call to the intrinsic with a call to the vector library
170
+ // function.
187
171
Type *ScalarRetTy = CI.getType ()->getScalarType ();
188
172
FunctionType *ScalarFTy = FunctionType::get (ScalarRetTy, ScalarTypes, false );
189
173
const std::string MangledName = VD->getVectorFunctionABIVariantString ();
190
174
auto OptInfo = VFABI::tryDemangleForVFABI (MangledName, ScalarFTy);
191
175
if (!OptInfo)
192
176
return false ;
193
177
194
- // get the vector FunctionType
195
- Module *M = CI.getModule ();
196
- auto OptFTy = VFABI::createFunctionType (*OptInfo, ScalarFTy);
197
- if (!OptFTy)
178
+ FunctionType *VectorFTy = VFABI::createFunctionType (*OptInfo, ScalarFTy);
179
+ if (!VectorFTy)
198
180
return false ;
199
181
200
- Function *OldFunc = CI.getCalledFunction ();
201
- FunctionType *VectorFTy = *OptFTy;
202
- Function *TLIFunc = getTLIFunction (M, VectorFTy, OldFunc, TLIName);
203
- return replaceWithTLIFunction (M, CI, OptInfo->Shape .VF , OldFunc, TLIFunc,
204
- VectorFTy, VD->isMasked ());
182
+ Function *FuncToReplace = CI.getCalledFunction ();
183
+ Function *TLIFunc = getTLIFunction (CI.getModule (), VectorFTy, FuncToReplace,
184
+ VD->getVectorFnName ());
185
+ replaceWithTLIFunction (CI, *OptInfo, TLIFunc, VectorFTy);
186
+
187
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `"
188
+ << FuncToReplace->getName () << " ` with call to `"
189
+ << TLIFunc->getName () << " `.\n " );
190
+ ++NumCallsReplaced;
191
+ return true ;
205
192
}
206
193
207
194
static bool runImpl (const TargetLibraryInfo &TLI, Function &F) {
0 commit comments