Skip to content

Commit b24f0e5

Browse files
Use createFunctionType to correctly replace veclib calls.
Split replaceWithTLIFunction method into two methods.
1 parent 8371909 commit b24f0e5

File tree

3 files changed

+149
-118
lines changed

3 files changed

+149
-118
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 117 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
#include "llvm/CodeGen/ReplaceWithVeclib.h"
1616
#include "llvm/ADT/STLExtras.h"
1717
#include "llvm/ADT/Statistic.h"
18+
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/Analysis/DemandedBits.h"
1920
#include "llvm/Analysis/GlobalsModRef.h"
2021
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
2122
#include "llvm/Analysis/TargetLibraryInfo.h"
2223
#include "llvm/Analysis/VectorUtils.h"
2324
#include "llvm/CodeGen/Passes.h"
25+
#include "llvm/IR/DerivedTypes.h"
2426
#include "llvm/IR/IRBuilder.h"
2527
#include "llvm/IR/InstIterator.h"
28+
#include "llvm/Support/TypeSize.h"
2629
#include "llvm/Transforms/Utils/ModuleUtils.h"
30+
#include <optional>
2731

2832
using namespace llvm;
2933

@@ -38,138 +42,166 @@ STATISTIC(NumTLIFuncDeclAdded,
3842
STATISTIC(NumFuncUsedAdded,
3943
"Number of functions added to `llvm.compiler.used`");
4044

41-
static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
42-
Module *M = CI.getModule();
43-
44-
Function *OldFunc = CI.getCalledFunction();
45-
46-
// Check if the vector library function is already declared in this module,
47-
// otherwise insert it.
45+
/// Returns a vector Function that it adds to the Module \p M. When an \p
46+
/// OptOldFunc is given, it copies its attributes to the newly created Function.
47+
Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
48+
std::optional<Function *> OptOldFunc,
49+
const StringRef TLIName) {
4850
Function *TLIFunc = M->getFunction(TLIName);
4951
if (!TLIFunc) {
50-
TLIFunc = Function::Create(OldFunc->getFunctionType(),
51-
Function::ExternalLinkage, TLIName, *M);
52-
TLIFunc->copyAttributesFrom(OldFunc);
52+
TLIFunc =
53+
Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
54+
if (OptOldFunc)
55+
TLIFunc->copyAttributesFrom(*OptOldFunc);
5356

5457
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
5558
<< TLIName << "` of type `" << *(TLIFunc->getType())
5659
<< "` to module.\n");
5760

5861
++NumTLIFuncDeclAdded;
59-
60-
// Add the freshly created function to llvm.compiler.used,
61-
// similar to as it is done in InjectTLIMappings
62+
// Add the freshly created function to llvm.compiler.used, similar to as it
63+
// is done in InjectTLIMappings
6264
appendToCompilerUsed(*M, {TLIFunc});
63-
6465
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
6566
<< "` to `@llvm.compiler.used`.\n");
6667
++NumFuncUsedAdded;
6768
}
69+
return TLIFunc;
70+
}
6871

69-
// Replace the call to the vector intrinsic with a call
70-
// to the corresponding function from the vector library.
72+
/// Replace the call to the vector intrinsic ( \p OldFunc ) with a call to the
73+
/// corresponding function from the vector library ( \p TLIFunc ).
74+
static bool replaceWithTLIFunction(const Module *M, CallInst &CI,
75+
const ElementCount &VecVF, Function *OldFunc,
76+
Function *TLIFunc, FunctionType *VecFTy,
77+
bool IsMasked) {
7178
IRBuilder<> IRBuilder(&CI);
7279
SmallVector<Value *> Args(CI.args());
80+
if (IsMasked) {
81+
if (Args.size() == VecFTy->getNumParams())
82+
static_assert(true && "mask was already in place");
83+
84+
auto *MaskTy = VectorType::get(Type::getInt1Ty(M->getContext()), VecVF);
85+
Args.push_back(Constant::getAllOnesValue(MaskTy));
86+
}
87+
7388
// Preserve the operand bundles.
7489
SmallVector<OperandBundleDef, 1> OpBundles;
7590
CI.getOperandBundlesAsDefs(OpBundles);
7691
CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
77-
assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
92+
assert(VecFTy == TLIFunc->getFunctionType() &&
7893
"Expecting function types to be identical");
7994
CI.replaceAllUsesWith(Replacement);
80-
if (isa<FPMathOperator>(Replacement)) {
81-
// Preserve fast math flags for FP math.
95+
// Preserve fast math flags for FP math.
96+
if (isa<FPMathOperator>(Replacement))
8297
Replacement->copyFastMathFlags(&CI);
83-
}
8498

8599
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
86-
<< OldFunc->getName() << "` with call to `" << TLIName
87-
<< "`.\n");
100+
<< OldFunc->getName() << "` with call to `"
101+
<< TLIFunc->getName() << "`.\n");
88102
++NumCallsReplaced;
89103
return true;
90104
}
91105

106+
/// Utility method to get the VecDesc, depending on whether there is a TLI
107+
/// mapping, either with or without a mask.
108+
static std::optional<const VecDesc *> getVecDesc(const TargetLibraryInfo &TLI,
109+
const StringRef &ScalarName,
110+
const ElementCount &VF) {
111+
const VecDesc *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true);
112+
const VecDesc *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false);
113+
// Invalid when there are both variants (ie masked and unmasked), or none
114+
if ((VDMasked == nullptr) == (VDNoMask == nullptr))
115+
return std::nullopt;
116+
117+
return {VDMasked != nullptr ? VDMasked : VDNoMask};
118+
}
119+
120+
/// Returns whether it is able to replace a call to the intrinsic \p CI with a
121+
/// TLI mapped call.
92122
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
93123
CallInst &CI) {
94-
if (!CI.getCalledFunction()) {
124+
if (!CI.getCalledFunction())
95125
return false;
96-
}
97126

98127
auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
99-
if (IntrinsicID == Intrinsic::not_intrinsic) {
100-
// Replacement is only performed for intrinsic functions
128+
// Replacement is only performed for intrinsic functions
129+
if (IntrinsicID == Intrinsic::not_intrinsic)
101130
return false;
102-
}
103131

104-
// Convert vector arguments to scalar type and check that
105-
// all vector operands have identical vector width.
132+
// Convert vector arguments to scalar type and check that all vector operands
133+
// have identical vector width.
106134
ElementCount VF = ElementCount::getFixed(0);
107135
SmallVector<Type *> ScalarTypes;
108-
bool MayBeMasked = false;
109136
for (auto Arg : enumerate(CI.args())) {
110-
auto *ArgType = Arg.value()->getType();
111-
// Vector calls to intrinsics can still have
112-
// scalar operands for specific arguments.
137+
auto *ArgTy = Arg.value()->getType();
113138
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
114-
ScalarTypes.push_back(ArgType);
115-
} else {
116-
// The argument in this place should be a vector if
117-
// this is a call to a vector intrinsic.
118-
auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
119-
if (!VectorArgTy) {
120-
// The argument is not a vector, do not perform
121-
// the replacement.
122-
return false;
123-
}
124-
ElementCount NumElements = VectorArgTy->getElementCount();
125-
if (NumElements.isScalable())
126-
MayBeMasked = true;
127-
128-
// The different arguments differ in vector size.
129-
if (VF.isNonZero() && VF != NumElements)
139+
ScalarTypes.push_back(ArgTy);
140+
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
141+
ScalarTypes.push_back(ArgTy->getScalarType());
142+
// Disallow vector arguments with different VFs. When processing the first
143+
// vector argument, store it's VF, and for the rest ensure that they match
144+
// it.
145+
if (VF.isZero())
146+
VF = VectorArgTy->getElementCount();
147+
else if (VF != VectorArgTy->getElementCount())
130148
return false;
131-
VF = NumElements;
132-
ScalarTypes.push_back(VectorArgTy->getElementType());
149+
} else {
150+
// enters when it is supposed to be a vector argument but it isn't.
151+
return false;
133152
}
134153
}
135154

136-
// Try to reconstruct the name for the scalar version of this
137-
// intrinsic using the intrinsic ID and the argument types
138-
// converted to scalar above.
139-
std::string ScalarName;
140-
if (Intrinsic::isOverloaded(IntrinsicID)) {
141-
ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
142-
} else {
143-
ScalarName = Intrinsic::getName(IntrinsicID).str();
144-
}
155+
// Try to reconstruct the name for the scalar version of this intrinsic using
156+
// the intrinsic ID and the argument types converted to scalar above.
157+
std::string ScalarName =
158+
(Intrinsic::isOverloaded(IntrinsicID)
159+
? Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule())
160+
: Intrinsic::getName(IntrinsicID).str());
145161

146-
if (!TLI.isFunctionVectorizable(ScalarName)) {
147-
// The TargetLibraryInfo does not contain a vectorized version of
148-
// the scalar function.
162+
// The TargetLibraryInfo does not contain a vectorized version of the scalar
163+
// function.
164+
if (!TLI.isFunctionVectorizable(ScalarName))
149165
return false;
150-
}
151166

152-
// Assume it has a mask when that is a possibility and has no mapping for
153-
// a Non-Masked variant.
154-
const bool IsMasked =
155-
MayBeMasked && !TLI.getVectorMappingInfo(ScalarName, VF, false);
156-
// Try to find the mapping for the scalar version of this intrinsic
157-
// and the exact vector width of the call operands in the
158-
// TargetLibraryInfo.
159-
StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF, IsMasked);
167+
auto OptVD = getVecDesc(TLI, ScalarName, VF);
168+
if (!OptVD)
169+
return false;
170+
171+
const VecDesc *VD = *OptVD;
172+
// Try to find the mapping for the scalar version of this intrinsic and the
173+
// exact vector width of the call operands in the TargetLibraryInfo.
174+
StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF, VD->isMasked());
160175
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
161176
<< ScalarName << "` and vector width " << VF << ".\n");
162177

163-
if (!TLIName.empty()) {
164-
// Found the correct mapping in the TargetLibraryInfo,
165-
// replace the call to the intrinsic with a call to
166-
// the vector library function.
167-
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
168-
<< "`.\n");
169-
return replaceWithTLIFunction(CI, TLIName);
170-
}
178+
// TLI failed to find a correct mapping.
179+
if (TLIName.empty())
180+
return false;
171181

172-
return false;
182+
// Find the vector Function and replace the call to the intrinsic with a call
183+
// to the vector library function.
184+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
185+
<< "`.\n");
186+
187+
Type *ScalarRetTy = CI.getType()->getScalarType();
188+
FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarTypes, false);
189+
const std::string MangledName = VD->getVectorFunctionABIVariantString();
190+
auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
191+
if (!OptInfo)
192+
return false;
193+
194+
// get the vector FunctionType
195+
Module *M = CI.getModule();
196+
auto OptFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
197+
if (!OptFTy)
198+
return false;
199+
200+
Function *OldFunc = CI.getCalledFunction();
201+
FunctionType *VectorFTy = *OptFTy;
202+
Function *TLIFunc = getTLIFunction(M, VectorFTy, OldFunc, TLIName);
203+
return replaceWithTLIFunction(M, CI, OptInfo->Shape.VF, OldFunc, TLIFunc,
204+
VectorFTy, VD->isMasked());
173205
}
174206

175207
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
@@ -185,9 +217,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
185217
}
186218
// Erase the calls to the intrinsics that have been replaced
187219
// with calls to the vector library.
188-
for (auto *CI : ReplacedCalls) {
220+
for (auto *CI : ReplacedCalls)
189221
CI->eraseFromParent();
190-
}
191222
return Changed;
192223
}
193224

@@ -207,10 +238,10 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
207238
PA.preserve<DemandedBitsAnalysis>();
208239
PA.preserve<OptimizationRemarkEmitterAnalysis>();
209240
return PA;
210-
} else {
211-
// The pass did not replace any calls, hence it preserves all analyses.
212-
return PreservedAnalyses::all();
213241
}
242+
243+
// The pass did not replace any calls, hence it preserves all analyses.
244+
return PreservedAnalyses::all();
214245
}
215246

216247
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)