Skip to content

[IR] Do not store Function inside BlockAddress #137958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CodeGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2286,7 +2286,7 @@ llvm::BlockAddress *CodeGenFunction::GetAddrOfLabel(const LabelDecl *L) {

// Make sure the indirect branch includes all of the address-taken blocks.
IndirectBranch->addDestination(BB);
return llvm::BlockAddress::get(CurFn, BB);
return llvm::BlockAddress::get(CurFn->getType(), BB);
}

llvm::BasicBlock *CodeGenFunction::GetIndirectGotoBlock() {
Expand Down
15 changes: 10 additions & 5 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -893,9 +893,9 @@ class ConstantTargetNone final : public ConstantData {
class BlockAddress final : public Constant {
friend class Constant;

constexpr static IntrusiveOperandsAllocMarker AllocMarker{2};
constexpr static IntrusiveOperandsAllocMarker AllocMarker{1};

BlockAddress(Function *F, BasicBlock *BB);
BlockAddress(Type *Ty, BasicBlock *BB);

void *operator new(size_t S) { return User::operator new(S, AllocMarker); }

Expand All @@ -912,6 +912,11 @@ class BlockAddress final : public Constant {
/// block must be embedded into a function.
static BlockAddress *get(BasicBlock *BB);

/// Return a BlockAddress for the specified basic block, which need not (yet)
/// be part of a function. The specified type must match the type of the
/// function the block will be inserted into.
static BlockAddress *get(Type *Ty, BasicBlock *BB);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also add a corresponding helper for SandboxIR? This overload is the only way to create a blockaddress for a block that has not been inserted into a function yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say "no unless it's actually needed". We should avoid creating blockaddresses for non-inserted blocks if possible.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think block addresses should be explicitly disallowed for unparented blocks.


/// Lookup an existing \c BlockAddress constant for the given BasicBlock.
///
/// \returns 0 if \c !BB->hasAddressTaken(), otherwise the \c BlockAddress.
Expand All @@ -920,8 +925,8 @@ class BlockAddress final : public Constant {
/// Transparently provide more efficient getOperand methods.
DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value);

Function *getFunction() const { return (Function *)Op<0>().get(); }
BasicBlock *getBasicBlock() const { return (BasicBlock *)Op<1>().get(); }
BasicBlock *getBasicBlock() const { return cast<BasicBlock>(Op<0>().get()); }
Function *getFunction() const { return getBasicBlock()->getParent(); }

/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Value *V) {
Expand All @@ -931,7 +936,7 @@ class BlockAddress final : public Constant {

template <>
struct OperandTraits<BlockAddress>
: public FixedNumOperandTraits<BlockAddress, 2> {};
: public FixedNumOperandTraits<BlockAddress, 1> {};

DEFINE_TRANSPARENT_OPERAND_ACCESSORS(BlockAddress, Value)

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
FwdBBs[BBID] = BasicBlock::Create(Context);
BB = FwdBBs[BBID];
}
C = BlockAddress::get(Fn, BB);
C = BlockAddress::get(Fn->getType(), BB);
break;
}
case BitcodeConstant::ConstantStructOpcode: {
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/CodeGen/AsmPrinter/WinCFGuard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ static bool isPossibleIndirectCallTarget(const Function *F) {
const Value *FnOrCast = Users.pop_back_val();
for (const Use &U : FnOrCast->uses()) {
const User *FnUser = U.getUser();
if (isa<BlockAddress>(FnUser)) {
// Block addresses are illegal to call.
continue;
}
if (const auto *Call = dyn_cast<CallBase>(FnUser)) {
if ((!Call->isCallee(&U) || U.get() != F) &&
!Call->getFunction()->getName().ends_with("$exit_thunk")) {
Expand Down
56 changes: 20 additions & 36 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1891,77 +1891,61 @@ void PoisonValue::destroyConstantImpl() {
getContext().pImpl->PVConstants.erase(getType());
}

/// Return the BlockAddress constant associated with \p BB in its context's
/// BlockAddresses table, creating one of type \p Ty if none is recorded yet.
/// If an entry for \p BB already exists it is returned as-is and \p Ty is not
/// consulted.
BlockAddress *BlockAddress::get(Type *Ty, BasicBlock *BB) {
// Bind a reference to the table slot keyed by the block alone, so that a
// newly created constant is stored back into the table by the assignment.
BlockAddress *&BA = BB->getContext().pImpl->BlockAddresses[BB];
// A null slot means no BlockAddress has been created for this block so far.
if (!BA)
BA = new BlockAddress(Ty, BB);
return BA;
}

BlockAddress *BlockAddress::get(BasicBlock *BB) {
assert(BB->getParent() && "Block must have a parent");
return get(BB->getParent(), BB);
return get(BB->getParent()->getType(), BB);
}

BlockAddress *BlockAddress::get(Function *F, BasicBlock *BB) {
BlockAddress *&BA =
F->getContext().pImpl->BlockAddresses[std::make_pair(F, BB)];
if (!BA)
BA = new BlockAddress(F, BB);

assert(BA->getFunction() == F && "Basic block moved between functions");
return BA;
assert(BB->getParent() == F && "Block not part of specified function");
return get(BB->getParent()->getType(), BB);
}

BlockAddress::BlockAddress(Function *F, BasicBlock *BB)
: Constant(PointerType::get(F->getContext(), F->getAddressSpace()),
Value::BlockAddressVal, AllocMarker) {
setOperand(0, F);
setOperand(1, BB);
BlockAddress::BlockAddress(Type *Ty, BasicBlock *BB)
: Constant(Ty, Value::BlockAddressVal, AllocMarker) {
setOperand(0, BB);
BB->AdjustBlockAddressRefCount(1);
}

BlockAddress *BlockAddress::lookup(const BasicBlock *BB) {
if (!BB->hasAddressTaken())
return nullptr;

const Function *F = BB->getParent();
assert(F && "Block must have a parent");
BlockAddress *BA =
F->getContext().pImpl->BlockAddresses.lookup(std::make_pair(F, BB));
BlockAddress *BA = BB->getContext().pImpl->BlockAddresses.lookup(BB);
assert(BA && "Refcount and block address map disagree!");
return BA;
}

/// Remove the constant from the constant table.
void BlockAddress::destroyConstantImpl() {
getFunction()->getType()->getContext().pImpl
->BlockAddresses.erase(std::make_pair(getFunction(), getBasicBlock()));
getType()->getContext().pImpl->BlockAddresses.erase(getBasicBlock());
getBasicBlock()->AdjustBlockAddressRefCount(-1);
}

Value *BlockAddress::handleOperandChangeImpl(Value *From, Value *To) {
// This could be replacing either the Basic Block or the Function. In either
// case, we have to remove the map entry.
Function *NewF = getFunction();
BasicBlock *NewBB = getBasicBlock();

if (From == NewF)
NewF = cast<Function>(To->stripPointerCasts());
else {
assert(From == NewBB && "From does not match any operand");
NewBB = cast<BasicBlock>(To);
}
assert(From == getBasicBlock());
BasicBlock *NewBB = cast<BasicBlock>(To);

// See if the 'new' entry already exists, if not, just update this in place
// and return early.
BlockAddress *&NewBA =
getContext().pImpl->BlockAddresses[std::make_pair(NewF, NewBB)];
BlockAddress *&NewBA = getContext().pImpl->BlockAddresses[NewBB];
if (NewBA)
return NewBA;

getBasicBlock()->AdjustBlockAddressRefCount(-1);

// Remove the old entry, this can't cause the map to rehash (just a
// tombstone will get added).
getContext().pImpl->BlockAddresses.erase(std::make_pair(getFunction(),
getBasicBlock()));
getContext().pImpl->BlockAddresses.erase(getBasicBlock());
NewBA = this;
setOperand(0, NewF);
setOperand(1, NewBB);
setOperand(0, NewBB);
getBasicBlock()->AdjustBlockAddressRefCount(1);

// If we just want to keep the existing value, then return null.
Expand Down
10 changes: 1 addition & 9 deletions llvm/lib/IR/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,6 @@ bool Function::hasAddressTaken(const User **PutOffender,
bool IgnoreCastedDirectCall) const {
for (const Use &U : uses()) {
const User *FU = U.getUser();
if (isa<BlockAddress>(FU))
continue;

if (IgnoreCallbackUses) {
AbstractCallSite ACS(&U);
if (ACS && ACS.isCallbackCall())
Expand Down Expand Up @@ -1033,12 +1030,7 @@ bool Function::isDefTriviallyDead() const {
!hasAvailableExternallyLinkage())
return false;

// Check if the function is used by anything other than a blockaddress.
for (const User *U : users())
if (!isa<BlockAddress>(U))
return false;

return true;
return use_empty();
}

/// callsFunctionThatReturnsTwice - Return true if the function has a call to
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/IR/LLVMContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1688,8 +1688,7 @@ class LLVMContextImpl {

StringMap<std::unique_ptr<ConstantDataSequential>> CDSConstants;

DenseMap<std::pair<const Function *, const BasicBlock *>, BlockAddress *>
BlockAddresses;
DenseMap<const BasicBlock *, BlockAddress *> BlockAddresses;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm working on moving this out of the LLVMContext


DenseMap<const GlobalValue *, DSOLocalEquivalent *> DSOLocalEquivalents;

Expand Down
11 changes: 0 additions & 11 deletions llvm/lib/Transforms/IPO/Attributor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1936,9 +1936,6 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred,
LLVM_DEBUG(dbgs() << "[Attributor] Function " << Fn.getName()
<< " has non call site use " << *U.get() << " in "
<< *U.getUser() << "\n");
// BlockAddress users are allowed.
if (isa<BlockAddress>(U.getUser()))
continue;
return false;
}

Expand Down Expand Up @@ -3061,14 +3058,6 @@ ChangeStatus Attributor::rewriteFunctionSignatures(
// function empty.
NewFn->splice(NewFn->begin(), OldFn);

// Fixup block addresses to reference new function.
SmallVector<BlockAddress *, 8u> BlockAddresses;
for (User *U : OldFn->users())
if (auto *BA = dyn_cast<BlockAddress>(U))
BlockAddresses.push_back(BA);
for (auto *BA : BlockAddresses)
BA->replaceAllUsesWith(BlockAddress::get(NewFn, BA->getBasicBlock()));

// Set of all "call-like" instructions that invoke the old function mapped
// to their new replacements.
SmallVector<std::pair<CallBase *, CallBase *>, 8> CallSitePairs;
Expand Down
24 changes: 3 additions & 21 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1678,11 +1678,8 @@ processGlobal(GlobalValue &GV,
/// Walk all of the direct calls of the specified function, changing them to
/// FastCC.
static void ChangeCalleesToFastCall(Function *F) {
for (User *U : F->users()) {
if (isa<BlockAddress>(U))
continue;
for (User *U : F->users())
cast<CallBase>(U)->setCallingConv(CallingConv::Fast);
}
}

static AttributeList StripAttr(LLVMContext &C, AttributeList Attrs,
Expand All @@ -1696,8 +1693,6 @@ static AttributeList StripAttr(LLVMContext &C, AttributeList Attrs,
static void RemoveAttribute(Function *F, Attribute::AttrKind A) {
F->setAttributes(StripAttr(F->getContext(), F->getAttributes(), A));
for (User *U : F->users()) {
if (isa<BlockAddress>(U))
continue;
CallBase *CB = cast<CallBase>(U);
CB->setAttributes(StripAttr(F->getContext(), CB->getAttributes(), A));
}
Expand All @@ -1722,8 +1717,6 @@ static bool hasChangeableCCImpl(Function *F) {
// Can't change CC of the function that either has musttail calls, or is a
// musttail callee itself
for (User *U : F->users()) {
if (isa<BlockAddress>(U))
continue;
CallInst* CI = dyn_cast<CallInst>(U);
if (!CI)
continue;
Expand Down Expand Up @@ -1772,9 +1765,6 @@ isValidCandidateForColdCC(Function &F,
return false;

for (User *U : F.users()) {
if (isa<BlockAddress>(U))
continue;

CallBase &CB = cast<CallBase>(*U);
Function *CallerFunc = CB.getParent()->getParent();
BlockFrequencyInfo &CallerBFI = GetBFI(*CallerFunc);
Expand All @@ -1787,11 +1777,8 @@ isValidCandidateForColdCC(Function &F,
}

static void changeCallSitesToColdCC(Function *F) {
for (User *U : F->users()) {
if (isa<BlockAddress>(U))
continue;
for (User *U : F->users())
cast<CallBase>(U)->setCallingConv(CallingConv::Cold);
}
}

// This function iterates over all the call instructions in the input Function
Expand Down Expand Up @@ -1832,12 +1819,7 @@ hasOnlyColdCalls(Function &F,

static bool hasMustTailCallers(Function *F) {
for (User *U : F->users()) {
CallBase *CB = dyn_cast<CallBase>(U);
if (!CB) {
assert(isa<BlockAddress>(U) &&
"Expected either CallBase or BlockAddress");
continue;
}
CallBase *CB = cast<CallBase>(U);
if (CB->isMustTailCall())
return true;
}
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/IPO/LowerTypeTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,9 +1920,9 @@ void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New,
bool IsJumpTableCanonical) {
SmallSetVector<Constant *, 4> Constants;
for (Use &U : llvm::make_early_inc_range(Old->uses())) {
// Skip block addresses and no_cfi values, which refer to the function
// body instead of the jump table.
if (isa<BlockAddress, NoCFIValue>(U.getUser()))
// Skip no_cfi values, which refer to the function body instead of the jump
// table.
if (isa<NoCFIValue>(U.getUser()))
continue;

// Skip direct calls to externally defined or non-dso_local functions.
Expand Down
5 changes: 1 addition & 4 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5734,10 +5734,7 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
auto IsCalled = [&](Function &F) {
if (Kernels.contains(&F))
return true;
for (const User *U : F.users())
if (!isa<BlockAddress>(U))
return true;
return false;
return !F.use_empty();
};

auto EmitRemark = [&](Function &F) {
Expand Down
7 changes: 0 additions & 7 deletions llvm/lib/Transforms/IPO/PartialInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,9 +916,6 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap(
};

for (User *User : Users) {
// Don't bother with BlockAddress used by CallBr for asm goto.
if (isa<BlockAddress>(User))
continue;
CallBase *CB = getSupportedCallBase(User);
Function *Caller = CB->getCaller();
if (CurrentCaller != Caller) {
Expand Down Expand Up @@ -1359,10 +1356,6 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) {

bool AnyInline = false;
for (User *User : Users) {
// Don't bother with BlockAddress used by CallBr for asm goto.
if (isa<BlockAddress>(User))
continue;

CallBase *CB = getSupportedCallBase(User);

if (isLimitReached())
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Transforms/IPO/SCCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,10 @@ static bool runIPSCCP(
for (Use &U : F->uses()) {
CallBase *CB = dyn_cast<CallBase>(U.getUser());
if (!CB) {
assert(isa<BlockAddress>(U.getUser()) ||
(isa<Constant>(U.getUser()) &&
all_of(U.getUser()->users(), [](const User *UserUser) {
return cast<IntrinsicInst>(UserUser)->isAssumeLikeIntrinsic();
})));
assert(isa<Constant>(U.getUser()) &&
all_of(U.getUser()->users(), [](const User *UserUser) {
return cast<IntrinsicInst>(UserUser)->isAssumeLikeIntrinsic();
}));
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

; INTERESTING: @blockaddr.table.other

; RESULT: @blockaddr.table.other = private unnamed_addr constant [2 x ptr] [ptr blockaddress(@bar, %L1), ptr blockaddress(@bar, %L2)]
; RESULT: @blockaddr.table.other = private unnamed_addr constant [2 x ptr] [ptr inttoptr (i32 1 to ptr), ptr inttoptr (i32 1 to ptr)]


@blockaddr.table.other = private unnamed_addr constant [2 x ptr] [ptr blockaddress(@bar, %L1), ptr blockaddress(@bar, %L2)]


; RESULT: define i32 @bar(
; RESULT-NOT: define i32 @bar(
define i32 @bar(i64 %arg0) {
entry:
%gep = getelementptr inbounds [2 x ptr], ptr @blockaddr.table.other, i64 0, i64 %arg0
Expand Down
Loading