Skip to content

Commit 8ed527c

Browse files
Reapply [TLI] Fix replace-with-veclib crash with invalid arguments (#77112)
Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI mapping do not match the original call. Now, it simply ignores such cases. Test require assertions as it accesses programmatically the debug log. NOTE: Originally submitted by commit 9fdc568, which was reverted by a300b24, as it was causing some linking issues: https://lab.llvm.org/buildbot/#/builders/234/builds/17734
1 parent a300b24 commit 8ed527c

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
111111
SmallVector<Type *, 8> ScalarArgTypes;
112112
std::string ScalarName;
113113
Function *FuncToReplace = nullptr;
114-
if (auto *CI = dyn_cast<CallInst>(&I)) {
114+
auto *CI = dyn_cast<CallInst>(&I);
115+
if (CI) {
115116
FuncToReplace = CI->getCalledFunction();
116117
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
117118
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
@@ -168,12 +169,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
168169
if (!OptInfo)
169170
return false;
170171

172+
// There is no guarantee that the vectorized instructions followed the VFABI
173+
// specification when being created, this is why we need to add extra check to
174+
// make sure that the operands of the vector function obtained via VFABI match
175+
// the operands of the original vector instruction.
176+
if (CI) {
177+
for (auto VFParam : OptInfo->Shape.Parameters) {
178+
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
179+
continue;
180+
181+
// tryDemangleForVFABI must return valid ParamPos, otherwise it could be
182+
// a bug in the VFABI parser.
183+
assert(VFParam.ParamPos < CI->arg_size() &&
184+
"ParamPos has invalid range.");
185+
Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
186+
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
187+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
188+
<< ". Wrong type at index " << VFParam.ParamPos
189+
<< ": " << *OrigTy << "\n");
190+
return false;
191+
}
192+
}
193+
}
194+
171195
FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
172196
if (!VectorFTy)
173197
return false;
174198

175199
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
176200
VD->getVectorFnName(), FuncToReplace);
201+
177202
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
178203
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
179204
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
220245
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
221246
auto Changed = runImpl(TLI, F);
222247
if (Changed) {
248+
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
249+
<< NumCallsReplaced << "\n");
250+
223251
PreservedAnalyses PA;
224252
PA.preserveSet<CFGAnalyses>();
225253
PA.preserve<TargetLibraryAnalysis>();

llvm/unittests/Analysis/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(LLVM_LINK_COMPONENTS
22
Analysis
33
AsmParser
4+
CodeGen
45
Core
56
Passes
67
Support
@@ -40,6 +41,7 @@ set(ANALYSIS_TEST_SOURCES
4041
PluginInlineAdvisorAnalysisTest.cpp
4142
PluginInlineOrderAnalysisTest.cpp
4243
ProfileSummaryInfoTest.cpp
44+
ReplaceWithVecLibTest.cpp
4345
ScalarEvolutionTest.cpp
4446
VectorFunctionABITest.cpp
4547
SparsePropagation.cpp
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/CodeGen/ReplaceWithVeclib.h"
10+
#include "llvm/Analysis/TargetLibraryInfo.h"
11+
#include "llvm/AsmParser/Parser.h"
12+
#include "llvm/IR/LLVMContext.h"
13+
#include "llvm/IR/Module.h"
14+
#include "llvm/Passes/PassBuilder.h"
15+
#include "llvm/Support/SourceMgr.h"
16+
#include "gtest/gtest.h"
17+
18+
using namespace llvm;
19+
20+
/// NOTE: Assertions must be enabled for these tests to run.
21+
#ifndef NDEBUG
22+
23+
namespace {
24+
25+
static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
26+
SMDiagnostic Err;
27+
std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
28+
if (!Mod)
29+
Err.print("ReplaceWithVecLibTest", errs());
30+
return Mod;
31+
}
32+
33+
/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
34+
/// allows checking that the pass won't crash when the function to replace (from
35+
/// the input IR) does not match the replacement function (derived from the
36+
/// VecDesc mapping).
37+
///
38+
class ReplaceWithVecLibTest : public ::testing::Test {
39+
40+
std::string getLastLine(std::string Out) {
41+
// remove any trailing '\n'
42+
if (!Out.empty() && *(Out.cend() - 1) == '\n')
43+
Out.pop_back();
44+
45+
size_t LastNL = Out.find_last_of('\n');
46+
return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
47+
}
48+
49+
protected:
50+
LLVMContext Ctx;
51+
52+
/// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
53+
/// pass. The pass should not crash even when the replacement function
54+
/// (derived from the \p VD mapping) does not match the function to be
55+
/// replaced (from the input \p IR).
56+
///
57+
/// \returns the last line of the standard error to be compared for
58+
/// correctness.
59+
std::string run(const VecDesc &VD, const char *IR) {
60+
// Create TLII and register it with FAM so it's preserved when
61+
// ReplaceWithVeclib pass runs.
62+
TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
63+
TLII.addVectorizableFunctions({VD});
64+
FunctionAnalysisManager FAM;
65+
FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
66+
67+
// Register and run the pass on the 'foo' function from the input IR.
68+
FunctionPassManager FPM;
69+
FPM.addPass(ReplaceWithVeclib());
70+
std::unique_ptr<Module> M = parseIR(Ctx, IR);
71+
PassBuilder PB;
72+
PB.registerFunctionAnalyses(FAM);
73+
74+
// Enable debugging and capture std error
75+
llvm::DebugFlag = true;
76+
testing::internal::CaptureStderr();
77+
FPM.run(*M->getFunction("foo"), FAM);
78+
return getLastLine(testing::internal::GetCapturedStderr());
79+
}
80+
};
81+
82+
} // end anonymous namespace
83+
84+
static const char *IR = R"IR(
85+
define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
86+
%call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
87+
ret <vscale x 4 x float> %call
88+
}
89+
90+
declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
91+
)IR";
92+
93+
// The VFABI prefix in TLI describes signature which is matching the powi
94+
// intrinsic declaration.
95+
TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
96+
VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
97+
ElementCount::getScalable(4), /*Masked*/ true,
98+
"_ZGVsMxvu"};
99+
EXPECT_EQ(run(CorrectVD, IR),
100+
"Instructions replaced with vector libraries: 1");
101+
}
102+
103+
// The VFABI prefix in TLI describes signature which is not matching the powi
104+
// intrinsic declaration.
105+
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
106+
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
107+
ElementCount::getScalable(4), /*Masked*/ true,
108+
"_ZGVsMxvv"};
109+
EXPECT_EQ(run(IncorrectVD, IR),
110+
"replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
111+
"type at index 1: i32");
112+
}
113+
#endif

0 commit comments

Comments
 (0)