Skip to content

Commit 8082f46

Browse files
Addressing review.
getTLIFunction is no longer an optional. It accepts a pointer for ScalarFunc
1 parent 4332484 commit 8082f46

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,24 @@ STATISTIC(NumFuncUsedAdded,
4343
"Number of functions added to `llvm.compiler.used`");
4444

4545
/// Returns a vector Function that it adds to the Module \p M. When an \p
46-
/// OptOldFunc is given, it copies its attributes to the newly created Function.
46+
/// ScalarFunc is not null, it copies its attributes to the newly created
47+
/// Function.
4748
Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
48-
std::optional<Function *> OptOldFunc,
49-
const StringRef TLIName) {
49+
Function *ScalarFunc, const StringRef TLIName) {
5050
Function *TLIFunc = M->getFunction(TLIName);
5151
if (!TLIFunc) {
5252
TLIFunc =
5353
Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
54-
if (OptOldFunc)
55-
TLIFunc->copyAttributesFrom(*OptOldFunc);
54+
if (ScalarFunc)
55+
TLIFunc->copyAttributesFrom(ScalarFunc);
5656

5757
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
5858
<< TLIName << "` of type `" << *(TLIFunc->getType())
5959
<< "` to module.\n");
6060

6161
++NumTLIFuncDeclAdded;
6262
// Add the freshly created function to llvm.compiler.used, similar to as it
63-
// is done in InjectTLIMappings
63+
// is done in InjectTLIMappings.
6464
appendToCompilerUsed(*M, {TLIFunc});
6565
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
6666
<< "` to `@llvm.compiler.used`.\n");
@@ -72,11 +72,11 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
7272
/// Replace the call to the vector intrinsic ( \p FuncToReplace ) with a call to
7373
/// the corresponding function from the vector library ( \p TLIFunc ).
7474
static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
75-
Function *TLIFunc, FunctionType *VecFTy) {
75+
Function *TLIVecFunc) {
7676
IRBuilder<> IRBuilder(&CI);
7777
SmallVector<Value *> Args(CI.args());
7878
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
79-
if (Args.size() == VecFTy->getNumParams())
79+
if (Args.size() == TLIVecFunc->getFunctionType()->getNumParams())
8080
static_assert(true && "mask was already in place");
8181

8282
auto *MaskTy =
@@ -88,9 +88,7 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
8888
// Preserve the operand bundles.
8989
SmallVector<OperandBundleDef, 1> OpBundles;
9090
CI.getOperandBundlesAsDefs(OpBundles);
91-
CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
92-
assert(VecFTy == TLIFunc->getFunctionType() &&
93-
"Expecting function types to be identical");
91+
CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
9492
CI.replaceAllUsesWith(Replacement);
9593
// Preserve fast math flags for FP math.
9694
if (isa<FPMathOperator>(Replacement))
@@ -102,10 +100,10 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
102100
static std::optional<const VecDesc *> getVecDesc(const TargetLibraryInfo &TLI,
103101
const StringRef &ScalarName,
104102
const ElementCount &VF) {
105-
if (auto *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true))
106-
return VDMasked;
107103
if (auto *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false))
108104
return VDNoMask;
105+
if (auto *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true))
106+
return VDMasked;
109107
return std::nullopt;
110108
}
111109

@@ -117,20 +115,20 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
117115
return false;
118116

119117
auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
120-
// Replacement is only performed for intrinsic functions
118+
// Replacement is only performed for intrinsic functions.
121119
if (IntrinsicID == Intrinsic::not_intrinsic)
122120
return false;
123121

124122
// Convert vector arguments to scalar type and check that all vector operands
125123
// have identical vector width.
126124
ElementCount VF = ElementCount::getFixed(0);
127-
SmallVector<Type *> ScalarTypes;
125+
SmallVector<Type *> ScalarArgTypes;
128126
for (auto Arg : enumerate(CI.args())) {
129127
auto *ArgTy = Arg.value()->getType();
130128
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
131-
ScalarTypes.push_back(ArgTy);
129+
ScalarArgTypes.push_back(ArgTy);
132130
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
133-
ScalarTypes.push_back(ArgTy->getScalarType());
131+
ScalarArgTypes.push_back(ArgTy->getScalarType());
134132
// Disallow vector arguments with different VFs. When processing the first
135133
// vector argument, store it's VF, and for the rest ensure that they match
136134
// it.
@@ -139,15 +137,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
139137
else if (VF != VectorArgTy->getElementCount())
140138
return false;
141139
} else
142-
// enters when it is supposed to be a vector argument but it isn't.
140+
// Exit when it is supposed to be a vector argument but it isn't.
143141
return false;
144142
}
145143

146144
// Try to reconstruct the name for the scalar version of this intrinsic using
147145
// the intrinsic ID and the argument types converted to scalar above.
148146
std::string ScalarName =
149147
(Intrinsic::isOverloaded(IntrinsicID)
150-
? Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule())
148+
? Intrinsic::getName(IntrinsicID, ScalarArgTypes, CI.getModule())
151149
: Intrinsic::getName(IntrinsicID).str());
152150

153151
// The TargetLibraryInfo does not contain a vectorized version of the scalar
@@ -169,7 +167,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
169167
// Replace the call to the intrinsic with a call to the vector library
170168
// function.
171169
Type *ScalarRetTy = CI.getType()->getScalarType();
172-
FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarTypes, false);
170+
FunctionType *ScalarFTy =
171+
FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
173172
const std::string MangledName = VD->getVectorFunctionABIVariantString();
174173
auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
175174
if (!OptInfo)
@@ -182,7 +181,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
182181
Function *FuncToReplace = CI.getCalledFunction();
183182
Function *TLIFunc = getTLIFunction(CI.getModule(), VectorFTy, FuncToReplace,
184183
VD->getVectorFnName());
185-
replaceWithTLIFunction(CI, *OptInfo, TLIFunc, VectorFTy);
184+
replaceWithTLIFunction(CI, *OptInfo, TLIFunc);
186185

187186
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
188187
<< FuncToReplace->getName() << "` with call to `"

0 commit comments

Comments
 (0)