Skip to content

[PGO] Add a unit test for the PGOInstrumentationGen pass #93636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/unittests/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(Coroutines)
add_subdirectory(Instrumentation)
add_subdirectory(IPO)
add_subdirectory(Scalar)
add_subdirectory(Utils)
Expand Down
16 changes: 16 additions & 0 deletions llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
set(LLVM_LINK_COMPONENTS
Analysis
AsmParser
Core
Instrumentation
Passes
Support
)

add_llvm_unittest(InstrumentationTests
PGOInstrumentationTest.cpp
)

target_link_libraries(InstrumentationTests PRIVATE LLVMTestingSupport)

set_property(TARGET InstrumentationTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")
192 changes: 192 additions & 0 deletions llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Module.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/ProfileData/InstrProf.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <tuple>

namespace {

using namespace llvm;

using testing::_;
using ::testing::DoDefault;
using ::testing::Invoke;
using ::testing::IsNull;
using ::testing::NotNull;
using ::testing::Ref;
using ::testing::Return;
using ::testing::Sequence;
using ::testing::Test;
using ::testing::TestParamInfo;
using ::testing::Values;
using ::testing::WithParamInterface;

template <typename Derived> class MockAnalysisHandleBase {
public:
class Analysis : public AnalysisInfoMixin<Analysis> {
public:
class Result {
public:
// Forward invalidation events to the mock handle.
bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) {
return Handle->invalidate(M, PA, Inv);
}

private:
explicit Result(Derived *Handle) : Handle(Handle) {}

friend MockAnalysisHandleBase;
Derived *Handle;
};

Result run(Module &M, ModuleAnalysisManager &AM) {
return Handle->run(M, AM);
}

private:
friend AnalysisInfoMixin<Analysis>;
friend MockAnalysisHandleBase;
static inline AnalysisKey Key;

Derived *Handle;

explicit Analysis(Derived *Handle) : Handle(Handle) {}
};

Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }

typename Analysis::Result getResult() {
return typename Analysis::Result(static_cast<Derived *>(this));
}

protected:
void setDefaults() {
ON_CALL(static_cast<Derived &>(*this), run(_, _))
.WillByDefault(Return(this->getResult()));
ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
.WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &) {
auto PAC = PA.template getChecker<Analysis>();
return !PAC.preserved() &&
!PAC.template preservedSet<AllAnalysesOn<Module>>();
}));
}

private:
friend Derived;
MockAnalysisHandleBase() = default;
};

class MockModuleAnalysisHandle
: public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
public:
MockModuleAnalysisHandle() { setDefaults(); }

MOCK_METHOD(typename Analysis::Result, run,
(Module &, ModuleAnalysisManager &));

MOCK_METHOD(bool, invalidate,
(Module &, const PreservedAnalyses &,
ModuleAnalysisManager::Invalidator &));
};

struct PGOInstrumentationGenTest
: public Test,
WithParamInterface<std::tuple<StringRef, StringRef>> {
LLVMContext Ctx;
ModulePassManager MPM;
PassBuilder PB;
MockModuleAnalysisHandle MMAHandle;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
LLVMContext Context;
std::unique_ptr<Module> M;

PGOInstrumentationGenTest() {
MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
MPM.addPass(
RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
MPM.addPass(PGOInstrumentationGen());
}

void parseAssembly(const StringRef IR) {
SMDiagnostic Error;
M = parseAssemblyString(IR, Error, Context);
std::string ErrMsg;
raw_string_ostream OS(ErrMsg);
Error.print("", OS);

// A failure here means that the test itself is buggy.
if (!M)
report_fatal_error(OS.str().c_str());
}
};

static constexpr StringRef CodeWithFuncDefs = R"(
define i32 @f(i32 %n) {
entry:
ret i32 0
})";

static constexpr StringRef CodeWithFuncDecls = R"(
declare i32 @f(i32);
)";

static constexpr StringRef CodeWithGlobals = R"(
@foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
declare i32 @f(i32);
)";

INSTANTIATE_TEST_SUITE_P(
PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
std::make_tuple(CodeWithGlobals, "instrument_globals")),
[](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
return std::get<1>(Info.param).str();
});

TEST_P(PGOInstrumentationGenTest, Instrumented) {
const StringRef Code = std::get<0>(GetParam());
parseAssembly(Code);

ASSERT_THAT(M, NotNull());

Sequence PassSequence;
EXPECT_CALL(MMAHandle, run(Ref(*M), _))
.InSequence(PassSequence)
.WillOnce(DoDefault());
EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
.InSequence(PassSequence)
.WillOnce(DoDefault());

MPM.run(*M, MAM);

const auto *IRInstrVar =
M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
EXPECT_THAT(IRInstrVar, NotNull());
EXPECT_FALSE(IRInstrVar->isDeclaration());
}

} // end anonymous namespace
Loading