Skip to content

[ctxprof] Handle instrumenting functions with musttail calls #135121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ class ContextNode final {
/// VOLATILE_PTRDECL is the same as above, but for volatile pointers;
///
/// MUTEXDECL takes one parameter, the name of a field that is a mutex.
#define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL) \
#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, VOLATILE_PTRDECL, \
MUTEXDECL) \
PTRDECL(FunctionData, Next) \
VOLATILE_PTRDECL(void, EntryAddress) \
VOLATILE_PTRDECL(ContextRoot, CtxRoot) \
CONTEXT_PTR \
VOLATILE_PTRDECL(ContextNode, FlatCtx) \
MUTEXDECL(Mutex)

Expand Down
32 changes: 20 additions & 12 deletions compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ void setupContext(ContextRoot *Root, GUID Guid, uint32_t NumCounters,

ContextRoot *FunctionData::getOrAllocateContextRoot() {
auto *Root = CtxRoot;
if (!canBeRoot(Root))
return Root;
if (Root)
return Root;
__sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> L(&Mutex);
Expand Down Expand Up @@ -328,8 +330,10 @@ ContextNode *getUnhandledContext(FunctionData &Data, void *Callee, GUID Guid,
if (!CtxRoot) {
if (auto *RAD = getRootDetector())
RAD->sample();
else if (auto *CR = Data.CtxRoot)
return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
else if (auto *CR = Data.CtxRoot) {
if (canBeRoot(CR))
return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
}
if (IsUnderContext || !__sanitizer::atomic_load_relaxed(&ProfilingStarted))
return TheScratchContext;
else
Expand Down Expand Up @@ -404,20 +408,21 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
ContextNode *__llvm_ctx_profile_start_context(FunctionData *FData, GUID Guid,
uint32_t Counters,
uint32_t Callsites) {

return tryStartContextGivenRoot(FData->getOrAllocateContextRoot(), Guid,
Counters, Callsites);
auto *Root = FData->getOrAllocateContextRoot();
assert(canBeRoot(Root));
return tryStartContextGivenRoot(Root, Guid, Counters, Callsites);
}

void __llvm_ctx_profile_release_context(FunctionData *FData)
SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
const auto *CurrentRoot = __llvm_ctx_profile_current_context_root;
if (!CurrentRoot || FData->CtxRoot != CurrentRoot)
auto *CR = FData->CtxRoot;
if (!CurrentRoot || CR != CurrentRoot)
return;
IsUnderContext = false;
assert(FData->CtxRoot);
assert(CR && canBeRoot(CR));
__llvm_ctx_profile_current_context_root = nullptr;
FData->CtxRoot->Taken.Unlock();
CR->Taken.Unlock();
}

void __llvm_ctx_profile_start_collection(unsigned AutodetectDuration) {
Expand Down Expand Up @@ -481,10 +486,13 @@ bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
// traversing it.
const auto *Pos = reinterpret_cast<const FunctionData *>(
__sanitizer::atomic_load_relaxed(&AllFunctionsData));
for (; Pos; Pos = Pos->Next)
if (!Pos->CtxRoot)
Writer.writeFlat(Pos->FlatCtx->guid(), Pos->FlatCtx->counters(),
Pos->FlatCtx->counters_size());
for (; Pos; Pos = Pos->Next) {
const auto *CR = Pos->CtxRoot;
if (!CR && canBeRoot(CR)) {
const auto *FP = Pos->FlatCtx;
Writer.writeFlat(FP->guid(), FP->counters(), FP->counters_size());
}
}
Writer.endFlatSection();
return true;
}
Expand Down
9 changes: 8 additions & 1 deletion compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ struct FunctionData {
#define _PTRDECL(T, N) T *N = nullptr;
#define _VOLATILE_PTRDECL(T, N) T *volatile N = nullptr;
#define _MUTEXDECL(N) ::__sanitizer::SpinMutex N;
CTXPROF_FUNCTION_DATA(_PTRDECL, _VOLATILE_PTRDECL, _MUTEXDECL)
#define _CONTEXT_PTR ContextRoot *CtxRoot = nullptr;
CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_PTR, _VOLATILE_PTRDECL, _MUTEXDECL)
#undef _CONTEXT_PTR
#undef _PTRDECL
#undef _VOLATILE_PTRDECL
#undef _MUTEXDECL
Expand All @@ -167,6 +169,11 @@ inline bool isScratch(const void *Ctx) {
return (reinterpret_cast<uint64_t>(Ctx) & 1);
}

// Reports whether Ctx is acceptable as a context root pointer. Only the
// reserved pointer value 0x1 is rejected; nullptr and every other pointer
// value yield true.
inline bool canBeRoot(const ContextRoot *Ctx) {
  const auto Bits = reinterpret_cast<uintptr_t>(Ctx);
  return Bits != 1U;
}

} // namespace __ctx_profile

extern "C" {
Expand Down
11 changes: 10 additions & 1 deletion compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,16 @@ void RootAutoDetector::start() {
atomic_load_relaxed(&RAD->FunctionDataListHead));
FD; FD = FD->Next) {
if (AllRoots.contains(reinterpret_cast<uptr>(FD->EntryAddress))) {
FD->getOrAllocateContextRoot();
if (canBeRoot(FD->CtxRoot)) {
FD->getOrAllocateContextRoot();
} else {
// FIXME: address this by informing the root detection algorithm
// to skip over such functions and pick the next function down the
// stack. At that point, this becomes an assert.
Printf("[ctxprof] Root auto-detector selected a musttail "
"function for root (%p). Ignoring\n",
FD->EntryAddress);
}
}
}
atomic_store_relaxed(&RAD->Self, 0);
Expand Down
10 changes: 10 additions & 0 deletions compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,13 @@ TEST_F(ContextTest, Dump) {
EXPECT_EQ(W2.FlatsWritten, 1);
EXPECT_EQ(W2.ExitedFlatCount, 1);
}

// Checks that a function whose FunctionData carries the reserved CtxRoot
// value 1 is handed a scratch context after collection has been started.
TEST_F(ContextTest, MustNotBeRoot) {
  // Stand-in for a callee; only its address is passed along.
  int CalleeStub = 0;
  FunctionData Data;
  // 1 is the pointer value that canBeRoot() rejects.
  Data.CtxRoot = reinterpret_cast<ContextRoot *>(1U);
  __llvm_ctx_profile_start_collection();
  auto *Ctx = __llvm_ctx_profile_get_context(&Data, &CalleeStub, 2, 3, 1);
  EXPECT_TRUE(isScratch(Ctx));
}
5 changes: 3 additions & 2 deletions llvm/include/llvm/ProfileData/CtxInstrContextNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ class ContextNode final {
/// VOLATILE_PTRDECL is the same as above, but for volatile pointers;
///
/// MUTEXDECL takes one parameter, the name of a field that is a mutex.
#define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL) \
#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, VOLATILE_PTRDECL, \
MUTEXDECL) \
PTRDECL(FunctionData, Next) \
VOLATILE_PTRDECL(void, EntryAddress) \
VOLATILE_PTRDECL(ContextRoot, CtxRoot) \
CONTEXT_PTR \
VOLATILE_PTRDECL(ContextNode, FlatCtx) \
MUTEXDECL(Mutex)

Expand Down
57 changes: 43 additions & 14 deletions llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -55,14 +57,15 @@ class CtxInstrumentationLowerer final {
Module &M;
ModuleAnalysisManager &MAM;
Type *ContextNodeTy = nullptr;
Type *FunctionDataTy = nullptr;
StructType *FunctionDataTy = nullptr;

DenseSet<const Function *> ContextRootSet;
Function *StartCtx = nullptr;
Function *GetCtx = nullptr;
Function *ReleaseCtx = nullptr;
GlobalVariable *ExpectedCalleeTLS = nullptr;
GlobalVariable *CallsiteInfoTLS = nullptr;
Constant *CannotBeRootInitializer = nullptr;

public:
CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
Expand Down Expand Up @@ -117,12 +120,29 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,

#define _PTRDECL(_, __) PointerTy,
#define _VOLATILE_PTRDECL(_, __) PointerTy,
#define _CONTEXT_ROOT PointerTy,
#define _MUTEXDECL(_) SanitizerMutexType,

FunctionDataTy = StructType::get(
M.getContext(),
{CTXPROF_FUNCTION_DATA(_PTRDECL, _VOLATILE_PTRDECL, _MUTEXDECL)});
M.getContext(), {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
_VOLATILE_PTRDECL, _MUTEXDECL)});
#undef _PTRDECL
#undef _CONTEXT_ROOT
#undef _VOLATILE_PTRDECL
#undef _MUTEXDECL

#define _PTRDECL(_, __) Constant::getNullValue(PointerTy),
#define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __)
#define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType),
#define _CONTEXT_ROOT \
Constant::getIntegerValue( \
PointerTy, \
APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)),
CannotBeRootInitializer = ConstantStruct::get(
FunctionDataTy, {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
_VOLATILE_PTRDECL, _MUTEXDECL)});
#undef _PTRDECL
#undef _CONTEXT_ROOT
#undef _VOLATILE_PTRDECL
#undef _MUTEXDECL

Expand All @@ -134,8 +154,8 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
I32Ty, /*NumCallsites*/
});

// Define a global for each entrypoint. We'll reuse the entrypoint's name as
// prefix. We assume the entrypoint names to be unique.
// Define a global for each entrypoint. We'll reuse the entrypoint's name
// as prefix. We assume the entrypoint names to be unique.
for (const auto &Fname : ContextRoots) {
if (const auto *F = M.getFunction(Fname)) {
if (F->isDeclaration())
Expand All @@ -145,10 +165,10 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
for (const auto &I : BB)
if (const auto *CB = dyn_cast<CallBase>(&I))
if (CB->isMustTailCall()) {
M.getContext().emitError(
"The function " + Fname +
" was indicated as a context root, but it features musttail "
"calls, which is not supported.");
M.getContext().emitError("The function " + Fname +
" was indicated as a context root, "
"but it features musttail "
"calls, which is not supported.");
}
}
}
Expand Down Expand Up @@ -240,6 +260,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
return false;
}();

if (HasMusttail && ContextRootSet.contains(&F)) {
F.getContext().emitError(
"[ctx_prof] A function with musttail calls was explicitly requested as "
"root. That is not supported because we cannot instrument a return "
"instruction to release the context: " +
F.getName());
return false;
}
auto &Head = F.getEntryBlock();
for (auto &I : Head) {
// Find the increment intrinsic in the entry basic block.
Expand All @@ -263,9 +291,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
// regular function)
// Don't set a name, they end up taking a lot of space and we don't need
// them.

// Zero-initialize the FunctionData, except for functions that have
// musttail calls. For those, we set the CtxRoot field to 1, a sentinel
// meaning "this function can't be set as a root".
TheRootFuctionData = new GlobalVariable(
M, FunctionDataTy, false, GlobalVariable::InternalLinkage,
Constant::getNullValue(FunctionDataTy));
HasMusttail ? CannotBeRootInitializer
: Constant::getNullValue(FunctionDataTy));

if (ContextRootSet.contains(&F)) {
Context = Builder.CreateCall(
Expand Down Expand Up @@ -366,10 +399,6 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
}
}
}
// FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
// to disallow this, (so this then stays as an error), another is to detect
// that and then do a wrapper or disallow the tail call. This only affects
// instrumentation, when we want to detect the call graph.
if (!HasMusttail && !ContextWasReleased)
F.getContext().emitError(
"[ctx_prof] A function that doesn't have musttail calls was "
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ declare void @bar()
; LOWERING: @[[GLOB4:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @[[GLOB5:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @[[GLOB6:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @[[GLOB7:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @[[GLOB7:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } { ptr null, ptr null, ptr inttoptr (i64 1 to ptr), ptr null, i8 0 }
;.
define void @foo(i32 %a, ptr %fct) {
; INSTRUMENT-LABEL: define void @foo(
Expand Down
Loading