Skip to content

Commit 4332484

Browse files
getVecDesc now prioritizes masked variant
Also further cleanup to address reviewers.
1 parent f52c492 commit 4332484

File tree

2 files changed

+36
-51
lines changed

2 files changed

+36
-51
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,20 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
6969
return TLIFunc;
7070
}
7171

72-
/// Replace the call to the vector intrinsic ( \p OldFunc ) with a call to the
73-
/// corresponding function from the vector library ( \p TLIFunc ).
74-
static bool replaceWithTLIFunction(const Module *M, CallInst &CI,
75-
const ElementCount &VecVF, Function *OldFunc,
76-
Function *TLIFunc, FunctionType *VecFTy,
77-
bool IsMasked) {
72+
/// Replace the call to the vector intrinsic ( \p CI ) with a call to
73+
/// the corresponding function from the vector library ( \p TLIFunc ).
74+
static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
75+
Function *TLIFunc, FunctionType *VecFTy) {
7876
IRBuilder<> IRBuilder(&CI);
7977
SmallVector<Value *> Args(CI.args());
80-
if (IsMasked) {
78+
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
8179
if (Args.size() == VecFTy->getNumParams())
8280
static_assert(true && "mask was already in place");
8381

84-
auto *MaskTy = VectorType::get(Type::getInt1Ty(M->getContext()), VecVF);
85-
Args.push_back(Constant::getAllOnesValue(MaskTy));
82+
auto *MaskTy =
83+
VectorType::get(Type::getInt1Ty(CI.getContext()), Info.Shape.VF);
84+
Args.insert(Args.begin() + OptMaskpos.value(),
85+
Constant::getAllOnesValue(MaskTy));
8686
}
8787

8888
// Preserve the operand bundles.
@@ -95,26 +95,18 @@ static bool replaceWithTLIFunction(const Module *M, CallInst &CI,
9595
// Preserve fast math flags for FP math.
9696
if (isa<FPMathOperator>(Replacement))
9797
Replacement->copyFastMathFlags(&CI);
98-
99-
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
100-
<< OldFunc->getName() << "` with call to `"
101-
<< TLIFunc->getName() << "`.\n");
102-
++NumCallsReplaced;
103-
return true;
10498
}
10599

106-
/// Utility method to get the VecDesc, depending on whether there is a TLI
107-
/// mapping, either with or without a mask.
100+
/// Utility method to get the VecDesc of the TLI mapping for the given scalar
101+
/// name and VF, if any; a masked mapping is preferred over an unmasked one.
108102
static std::optional<const VecDesc *> getVecDesc(const TargetLibraryInfo &TLI,
109103
const StringRef &ScalarName,
110104
const ElementCount &VF) {
111-
const VecDesc *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true);
112-
const VecDesc *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false);
113-
// Invalid when there are both variants (ie masked and unmasked), or none
114-
if ((VDMasked == nullptr) == (VDNoMask == nullptr))
115-
return std::nullopt;
116-
117-
return {VDMasked != nullptr ? VDMasked : VDNoMask};
105+
if (auto *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true))
106+
return VDMasked;
107+
if (auto *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false))
108+
return VDNoMask;
109+
return std::nullopt;
118110
}
119111

120112
/// Returns whether it is able to replace a call to the intrinsic \p CI with a
@@ -146,10 +138,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
146138
VF = VectorArgTy->getElementCount();
147139
else if (VF != VectorArgTy->getElementCount())
148140
return false;
149-
} else {
141+
} else
150142
// Reached when the argument is supposed to be a vector but it isn't.
151143
return false;
152-
}
153144
}
154145

155146
// Try to reconstruct the name for the scalar version of this intrinsic using
@@ -164,44 +155,40 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
164155
if (!TLI.isFunctionVectorizable(ScalarName))
165156
return false;
166157

158+
// Try to find the mapping for the scalar version of this intrinsic and the
159+
// exact vector width of the call operands in the TargetLibraryInfo.
167160
auto OptVD = getVecDesc(TLI, ScalarName, VF);
168161
if (!OptVD)
169162
return false;
170163

171164
const VecDesc *VD = *OptVD;
172-
// Try to find the mapping for the scalar version of this intrinsic and the
173-
// exact vector width of the call operands in the TargetLibraryInfo.
174-
StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF, VD->isMasked());
175-
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
176-
<< ScalarName << "` and vector width " << VF << ".\n");
177-
178-
// TLI failed to find a correct mapping.
179-
if (TLIName.empty())
180-
return false;
181-
182-
// Find the vector Function and replace the call to the intrinsic with a call
183-
// to the vector library function.
184-
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
185-
<< "`.\n");
165+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
166+
<< "` and vector width " << VF << " to: `"
167+
<< VD->getVectorFnName() << "`.\n");
186168

169+
// Replace the call to the intrinsic with a call to the vector library
170+
// function.
187171
Type *ScalarRetTy = CI.getType()->getScalarType();
188172
FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarTypes, false);
189173
const std::string MangledName = VD->getVectorFunctionABIVariantString();
190174
auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
191175
if (!OptInfo)
192176
return false;
193177

194-
// get the vector FunctionType
195-
Module *M = CI.getModule();
196-
auto OptFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
197-
if (!OptFTy)
178+
FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
179+
if (!VectorFTy)
198180
return false;
199181

200-
Function *OldFunc = CI.getCalledFunction();
201-
FunctionType *VectorFTy = *OptFTy;
202-
Function *TLIFunc = getTLIFunction(M, VectorFTy, OldFunc, TLIName);
203-
return replaceWithTLIFunction(M, CI, OptInfo->Shape.VF, OldFunc, TLIFunc,
204-
VectorFTy, VD->isMasked());
182+
Function *FuncToReplace = CI.getCalledFunction();
183+
Function *TLIFunc = getTLIFunction(CI.getModule(), VectorFTy, FuncToReplace,
184+
VD->getVectorFnName());
185+
replaceWithTLIFunction(CI, *OptInfo, TLIFunc, VectorFTy);
186+
187+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
188+
<< FuncToReplace->getName() << "` with call to `"
189+
<< TLIFunc->getName() << "`.\n");
190+
++NumCallsReplaced;
191+
return true;
205192
}
206193

207194
static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {

llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
target triple = "aarch64-unknown-linux-gnu"
55

6-
; NOTE: The existing TLI mappings are not used since the -replace-with-veclib pass is broken for scalable vectors.
7-
86
;.
97
; CHECK: @llvm.compiler.used = appending global [16 x ptr] [ptr @_ZGVsMxv_cos, ptr @_ZGVsMxv_cosf, ptr @_ZGVsMxv_exp, ptr @_ZGVsMxv_expf, ptr @_ZGVsMxv_exp2, ptr @_ZGVsMxv_exp2f, ptr @_ZGVsMxv_exp10, ptr @_ZGVsMxv_exp10f, ptr @_ZGVsMxv_log, ptr @_ZGVsMxv_logf, ptr @_ZGVsMxv_log10, ptr @_ZGVsMxv_log10f, ptr @_ZGVsMxv_log2, ptr @_ZGVsMxv_log2f, ptr @_ZGVsMxv_sin, ptr @_ZGVsMxv_sinf], section "llvm.metadata"
108
;.

0 commit comments

Comments
 (0)