@@ -108,15 +108,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
108
108
// Compute the argument types of the corresponding scalar call and the scalar
109
109
// function name. For calls, it additionally finds the function to replace
110
110
// and checks that all vector operands match the previously found EC.
111
- SmallVector<Type *, 8 > ScalarArgTypes;
111
+ SmallVector<Type *, 8 > ScalarArgTypes, OrigArgTypes ;
112
112
std::string ScalarName;
113
113
Function *FuncToReplace = nullptr ;
114
- if (auto *CI = dyn_cast<CallInst>(&I)) {
114
+ auto *CI = dyn_cast<CallInst>(&I);
115
+ if (CI) {
115
116
FuncToReplace = CI->getCalledFunction ();
116
117
Intrinsic::ID IID = FuncToReplace->getIntrinsicID ();
117
118
assert (IID != Intrinsic::not_intrinsic && " Not an intrinsic" );
118
119
for (auto Arg : enumerate(CI->args ())) {
119
120
auto *ArgTy = Arg.value ()->getType ();
121
+ OrigArgTypes.push_back (ArgTy);
120
122
if (isVectorIntrinsicWithScalarOpAtArg (IID, Arg.index ())) {
121
123
ScalarArgTypes.push_back (ArgTy);
122
124
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
@@ -174,6 +176,24 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
174
176
175
177
Function *TLIFunc = getTLIFunction (I.getModule (), VectorFTy,
176
178
VD->getVectorFnName (), FuncToReplace);
179
+
180
+ // For calls, bail out when their arguments do not match with the TLI mapping.
181
+ if (CI) {
182
+ int IdxNonPred = 0 ;
183
+ for (auto [OrigTy, VFParam] :
184
+ zip (OrigArgTypes, OptInfo->Shape .Parameters )) {
185
+ if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
186
+ continue ;
187
+ ++IdxNonPred;
188
+ if (OrigTy->isVectorTy () != (VFParam.ParamKind == VFParamKind::Vector)) {
189
+ LLVM_DEBUG (dbgs () << DEBUG_TYPE
190
+ << " : Will not replace: wrong type at index: "
191
+ << IdxNonPred << " : " << *OrigTy << " \n " );
192
+ return false ;
193
+ }
194
+ }
195
+ }
196
+
177
197
replaceWithTLIFunction (I, *OptInfo, TLIFunc);
178
198
LLVM_DEBUG (dbgs () << DEBUG_TYPE << " : Replaced call to `" << ScalarName
179
199
<< " ` with call to `" << TLIFunc->getName () << " `.\n " );
0 commit comments