Skip to content

Commit 8e514c5

Browse files
Reapply [TLI] Fix replace-with-veclib crash with invalid arguments (#77945)
Fix a crash of `replace-with-veclib` pass, when the arguments of the TLI mapping do not match the original call. Now, it simply ignores such cases. Test require assertions as it accesses programmatically the debug log. Reapplies reverted PR #77112
1 parent a974303 commit 8e514c5

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
111111
SmallVector<Type *, 8> ScalarArgTypes;
112112
std::string ScalarName;
113113
Function *FuncToReplace = nullptr;
114-
if (auto *CI = dyn_cast<CallInst>(&I)) {
114+
auto *CI = dyn_cast<CallInst>(&I);
115+
if (CI) {
115116
FuncToReplace = CI->getCalledFunction();
116117
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
117118
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
@@ -168,12 +169,36 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
168169
if (!OptInfo)
169170
return false;
170171

172+
// There is no guarantee that the vectorized instructions followed the VFABI
173+
// specification when being created, this is why we need to add extra check to
174+
// make sure that the operands of the vector function obtained via VFABI match
175+
// the operands of the original vector instruction.
176+
if (CI) {
177+
for (auto VFParam : OptInfo->Shape.Parameters) {
178+
if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
179+
continue;
180+
181+
// tryDemangleForVFABI must return valid ParamPos, otherwise it could be
182+
// a bug in the VFABI parser.
183+
assert(VFParam.ParamPos < CI->arg_size() &&
184+
"ParamPos has invalid range.");
185+
Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
186+
if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
187+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
188+
<< ". Wrong type at index " << VFParam.ParamPos
189+
<< ": " << *OrigTy << "\n");
190+
return false;
191+
}
192+
}
193+
}
194+
171195
FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
172196
if (!VectorFTy)
173197
return false;
174198

175199
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
176200
VD->getVectorFnName(), FuncToReplace);
201+
177202
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
178203
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
179204
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,6 +245,9 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
220245
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
221246
auto Changed = runImpl(TLI, F);
222247
if (Changed) {
248+
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
249+
<< NumCallsReplaced << "\n");
250+
223251
PreservedAnalyses PA;
224252
PA.preserveSet<CFGAnalyses>();
225253
PA.preserve<TargetLibraryAnalysis>();

llvm/unittests/Analysis/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(LLVM_LINK_COMPONENTS
22
Analysis
33
AsmParser
4+
CodeGen
45
Core
56
Passes
67
Support
@@ -40,6 +41,7 @@ set(ANALYSIS_TEST_SOURCES
4041
PluginInlineAdvisorAnalysisTest.cpp
4142
PluginInlineOrderAnalysisTest.cpp
4243
ProfileSummaryInfoTest.cpp
44+
ReplaceWithVecLibTest.cpp
4345
ScalarEvolutionTest.cpp
4446
VectorFunctionABITest.cpp
4547
SparsePropagation.cpp
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===--- ReplaceWithVecLibTest.cpp - replace-with-veclib unit tests -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/CodeGen/ReplaceWithVeclib.h"
10+
#include "llvm/Analysis/TargetLibraryInfo.h"
11+
#include "llvm/AsmParser/Parser.h"
12+
#include "llvm/IR/LLVMContext.h"
13+
#include "llvm/IR/Module.h"
14+
#include "llvm/Passes/PassBuilder.h"
15+
#include "llvm/Support/SourceMgr.h"
16+
#include "gtest/gtest.h"
17+
18+
using namespace llvm;
19+
20+
/// NOTE: Assertions must be enabled for these tests to run.
21+
#ifndef NDEBUG
22+
23+
namespace {
24+
25+
static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
26+
SMDiagnostic Err;
27+
std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
28+
if (!Mod)
29+
Err.print("ReplaceWithVecLibTest", errs());
30+
return Mod;
31+
}
32+
33+
/// Runs ReplaceWithVecLib with different TLIIs that have custom VecDescs. This
34+
/// allows checking that the pass won't crash when the function to replace (from
35+
/// the input IR) does not match the replacement function (derived from the
36+
/// VecDesc mapping).
37+
///
38+
class ReplaceWithVecLibTest : public ::testing::Test {
39+
40+
std::string getLastLine(std::string Out) {
41+
// remove any trailing '\n'
42+
if (!Out.empty() && *(Out.cend() - 1) == '\n')
43+
Out.pop_back();
44+
45+
size_t LastNL = Out.find_last_of('\n');
46+
return (LastNL == std::string::npos) ? Out : Out.substr(LastNL + 1);
47+
}
48+
49+
protected:
50+
LLVMContext Ctx;
51+
52+
/// Creates TLII using the given \p VD, and then runs the ReplaceWithVeclib
53+
/// pass. The pass should not crash even when the replacement function
54+
/// (derived from the \p VD mapping) does not match the function to be
55+
/// replaced (from the input \p IR).
56+
///
57+
/// \returns the last line of the standard error to be compared for
58+
/// correctness.
59+
std::string run(const VecDesc &VD, const char *IR) {
60+
// Create TLII and register it with FAM so it's preserved when
61+
// ReplaceWithVeclib pass runs.
62+
TargetLibraryInfoImpl TLII = TargetLibraryInfoImpl(Triple());
63+
TLII.addVectorizableFunctions({VD});
64+
FunctionAnalysisManager FAM;
65+
FAM.registerPass([&TLII]() { return TargetLibraryAnalysis(TLII); });
66+
67+
// Register and run the pass on the 'foo' function from the input IR.
68+
FunctionPassManager FPM;
69+
FPM.addPass(ReplaceWithVeclib());
70+
std::unique_ptr<Module> M = parseIR(Ctx, IR);
71+
PassBuilder PB;
72+
PB.registerFunctionAnalyses(FAM);
73+
74+
// Enable debugging and capture std error
75+
llvm::DebugFlag = true;
76+
testing::internal::CaptureStderr();
77+
FPM.run(*M->getFunction("foo"), FAM);
78+
return getLastLine(testing::internal::GetCapturedStderr());
79+
}
80+
};
81+
82+
} // end anonymous namespace
83+
84+
static const char *IR = R"IR(
85+
define <vscale x 4 x float> @foo(<vscale x 4 x float> %in){
86+
%call = call <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float> %in, i32 3)
87+
ret <vscale x 4 x float> %call
88+
}
89+
90+
declare <vscale x 4 x float> @llvm.powi.f32.i32(<vscale x 4 x float>, i32) #0
91+
)IR";
92+
93+
// The VFABI prefix in TLI describes signature which is matching the powi
94+
// intrinsic declaration.
95+
TEST_F(ReplaceWithVecLibTest, TestValidMapping) {
96+
VecDesc CorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvu_powi",
97+
ElementCount::getScalable(4), /*Masked*/ true,
98+
"_ZGVsMxvu"};
99+
EXPECT_EQ(run(CorrectVD, IR),
100+
"Instructions replaced with vector libraries: 1");
101+
}
102+
103+
// The VFABI prefix in TLI describes signature which is not matching the powi
104+
// intrinsic declaration.
105+
TEST_F(ReplaceWithVecLibTest, TestInvalidMapping) {
106+
VecDesc IncorrectVD = {"llvm.powi.f32.i32", "_ZGVsMxvv_powi",
107+
ElementCount::getScalable(4), /*Masked*/ true,
108+
"_ZGVsMxvv"};
109+
EXPECT_EQ(run(IncorrectVD, IR),
110+
"replace-with-veclib: Will not replace: llvm.powi.f32.i32. Wrong "
111+
"type at index 1: i32");
112+
}
113+
#endif

0 commit comments

Comments
 (0)