8
8
//
9
9
// Replaces LLVM IR instructions with vector operands (i.e., the frem
10
10
// instruction or calls to LLVM intrinsics) with matching calls to functions
11
- // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface
11
+ // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
12
12
//
13
13
//===----------------------------------------------------------------------===//
14
14
@@ -70,7 +70,7 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
70
70
}
71
71
72
72
/// Replace the instruction \p I with a call to the corresponding function from
73
- /// the vector library ( \p TLIVecFunc ).
73
+ /// the vector library (\p TLIVecFunc).
74
74
static void replaceWithTLIFunction (Instruction &I, VFInfo &Info,
75
75
Function *TLIVecFunc) {
76
76
IRBuilder<> IRBuilder (&I);
@@ -100,26 +100,26 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
100
100
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
101
101
static bool replaceWithCallToVeclib (const TargetLibraryInfo &TLI,
102
102
Instruction &I) {
103
+ auto *VTy = dyn_cast<VectorType>(I.getType ());
104
+ ElementCount EC (VTy ? VTy->getElementCount () : ElementCount::getFixed (0 ));
105
+ // Compute the argument types of the corresponding scalar call and the scalar
106
+ // function name. For calls, it additionally finds the function to replace
107
+ // and checks that all vector operands match the previously found EC.
108
+ SmallVector<Type *, 8 > ScalarArgTypes;
103
109
std::string ScalarName;
104
- ElementCount EC = ElementCount::getFixed (0 );
105
110
Function *FuncToReplace = nullptr ;
106
- SmallVector<Type *, 8 > ScalarArgTypes;
107
- // Compute the argument types of the corresponding scalar call, the scalar
108
- // function name, and EC. For calls, it additionally checks if in the vector
109
- // call, all vector operands have the same EC.
110
111
if (auto *CI = dyn_cast<CallInst>(&I)) {
111
- Intrinsic::ID IID = CI->getCalledFunction ()->getIntrinsicID ();
112
- assert (IID != Intrinsic::not_intrinsic && " Not an intrinsic" );
113
112
FuncToReplace = CI->getCalledFunction ();
113
+ Intrinsic::ID IID = FuncToReplace->getIntrinsicID ();
114
+ assert (IID != Intrinsic::not_intrinsic && " Not an intrinsic" );
114
115
for (auto Arg : enumerate(CI->args ())) {
115
116
auto *ArgTy = Arg.value ()->getType ();
116
117
if (isVectorIntrinsicWithScalarOpAtArg (IID, Arg.index ())) {
117
118
ScalarArgTypes.push_back (ArgTy);
118
119
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
119
120
ScalarArgTypes.push_back (VectorArgTy->getElementType ());
120
- // Disallow vector arguments with different VFs. When processing the
121
- // first vector argument, store it's VF, and for the rest ensure that
122
- // they match it.
121
+ // When return type is void, set EC to the first vector argument, and
122
+ // disallow vector arguments with different ECs.
123
123
if (EC.isZero ())
124
124
EC = VectorArgTy->getElementCount ();
125
125
else if (EC != VectorArgTy->getElementCount ())
@@ -134,15 +134,13 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
134
134
? Intrinsic::getName (IID, ScalarArgTypes, I.getModule ())
135
135
: Intrinsic::getName (IID).str ();
136
136
} else {
137
- assert (I.getType ()->isVectorTy () && " Instruction must use vectors" );
137
+ assert (VTy && " Return type must be a vector" );
138
+ auto *ScalarTy = VTy->getScalarType ();
138
139
LibFunc Func;
139
- auto *ScalarTy = I.getType ()->getScalarType ();
140
140
if (!TLI.getLibFunc (I.getOpcode (), ScalarTy, Func))
141
141
return false ;
142
142
ScalarName = TLI.getName (Func);
143
143
ScalarArgTypes = {ScalarTy, ScalarTy};
144
- if (auto *VTy = dyn_cast<VectorType>(I.getType ()))
145
- EC = VTy->getElementCount ();
146
144
}
147
145
148
146
// Try to find the mapping for the scalar version of this intrinsic and the
@@ -180,13 +178,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
180
178
return true ;
181
179
}
182
180
183
- /// Supported instructions \p I are either frem or CallInsts to intrinsics.
181
+ /// Supported instruction \p I must be a vectorized frem or a call to an
182
+ /// intrinsic that returns either void or a vector.
184
183
static bool isSupportedInstruction (Instruction *I) {
184
+ Type *Ty = I->getType ();
185
185
if (auto *CI = dyn_cast<CallInst>(I))
186
- return CI->getCalledFunction () &&
186
+ return (Ty-> isVectorTy () || Ty-> isVoidTy ()) && CI->getCalledFunction () &&
187
187
CI->getCalledFunction ()->getIntrinsicID () !=
188
188
Intrinsic::not_intrinsic;
189
- if (I->getOpcode () == Instruction::FRem && I-> getType () ->isVectorTy ())
189
+ if (I->getOpcode () == Instruction::FRem && Ty ->isVectorTy ())
190
190
return true ;
191
191
return false ;
192
192
}
0 commit comments