Skip to content

[BOLT][AArch64] Implement PLTCall optimization #93584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions bolt/include/bolt/Core/MCPlusBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1412,13 +1412,14 @@ class MCPlusBuilder {
return false;
}

/// Modify a direct call instruction \p Inst with an indirect call taking
/// a destination from a memory location pointed by \p TargetLocation symbol.
virtual bool convertCallToIndirectCall(MCInst &Inst,
const MCSymbol *TargetLocation,
MCContext *Ctx) {
/// Creates an indirect call to the function within the \p DirectCall PLT
/// stub. The function's memory location is pointed to by the
/// \p TargetLocation symbol.
virtual InstructionListType
createIndirectPltCall(const MCInst &DirectCall,
const MCSymbol *TargetLocation, MCContext *Ctx) {
llvm_unreachable("not implemented");
return false;
return {};
}

/// Morph an indirect call into a load where \p Reg holds the call target.
Expand Down
19 changes: 11 additions & 8 deletions bolt/lib/Passes/PLTCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ Error PLTCall::runOnFunctions(BinaryContext &BC) {
return Error::success();

uint64_t NumCallsOptimized = 0;
for (auto &It : BC.getBinaryFunctions()) {
BinaryFunction &Function = It.second;
for (auto &BFI : BC.getBinaryFunctions()) {
BinaryFunction &Function = BFI.second;
if (!shouldOptimize(Function))
continue;

Expand All @@ -61,18 +61,21 @@ Error PLTCall::runOnFunctions(BinaryContext &BC) {
if (opts::PLT == OT_HOT && !BB.getKnownExecutionCount())
continue;

for (MCInst &Instr : BB) {
if (!BC.MIB->isCall(Instr))
for (auto II = BB.begin(); II != BB.end(); II++) {
if (!BC.MIB->isCall(*II))
continue;
const MCSymbol *CallSymbol = BC.MIB->getTargetSymbol(Instr);
const MCSymbol *CallSymbol = BC.MIB->getTargetSymbol(*II);
if (!CallSymbol)
continue;
const BinaryFunction *CalleeBF = BC.getFunctionForSymbol(CallSymbol);
if (!CalleeBF || !CalleeBF->isPLTFunction())
continue;
BC.MIB->convertCallToIndirectCall(Instr, CalleeBF->getPLTSymbol(),
BC.Ctx.get());
BC.MIB->addAnnotation(Instr, "PLTCall", true);
const InstructionListType NewCode = BC.MIB->createIndirectPltCall(
*II, CalleeBF->getPLTSymbol(), BC.Ctx.get());
II = BB.replaceInstruction(II, NewCode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The std::advance() call after this line uses NewCode.size() - 1 as its second argument. I think it makes sense to ensure that the size is not zero by adding a corresponding assert.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, added.

assert(!NewCode.empty() && "PLT Call replacement must be non-empty");
std::advance(II, NewCode.size() - 1);
BC.MIB->addAnnotation(*II, "PLTCall", true);
++NumCallsOptimized;
}
}
Expand Down
41 changes: 41 additions & 0 deletions bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,47 @@ class AArch64MCPlusBuilder : public MCPlusBuilder {
return true;
}

/// Build the instruction sequence that replaces \p DirectCall with an
/// indirect call whose target is loaded through \p TargetLocation.
///
/// \p DirectCall must be a BL, or a B carrying the tail-call annotation.
/// \returns three instructions, in emission order:
///   adrp x16, <symbol>            (R_AARCH64_ADR_GOT_PAGE)
///   ldr  x17, [x16, #<offset>]    (R_AARCH64_LD64_GOT_LO12_NC)
///   blr  x17                      ('br x17' for tail calls)
InstructionListType createIndirectPltCall(const MCInst &DirectCall,
                                          const MCSymbol *TargetLocation,
                                          MCContext *Ctx) override {
  const bool TailCall = isTailCall(DirectCall);
  assert((DirectCall.getOpcode() == AArch64::BL ||
          (DirectCall.getOpcode() == AArch64::B && TailCall)) &&
         "64-bit direct (tail) call instruction expected");

  // Step 1: adrp x16, <symbol>. The immediate starts as a placeholder and
  // is then rewritten into a symbol reference.
  MCInst PageAddr;
  PageAddr.setOpcode(AArch64::ADRP);
  PageAddr.addOperand(MCOperand::createReg(AArch64::X16));
  PageAddr.addOperand(MCOperand::createImm(0));
  setOperandToSymbolRef(PageAddr, /* OpNum */ 1, TargetLocation,
                        /* Addend */ 0, Ctx, ELF::R_AARCH64_ADR_GOT_PAGE);

  // Step 2: ldr x17, [x16, #<offset>]. Operand 2 is the placeholder that
  // becomes the symbol reference.
  MCInst TargetLoad;
  TargetLoad.setOpcode(AArch64::LDRXui);
  TargetLoad.addOperand(MCOperand::createReg(AArch64::X17));
  TargetLoad.addOperand(MCOperand::createReg(AArch64::X16));
  TargetLoad.addOperand(MCOperand::createImm(0));
  setOperandToSymbolRef(TargetLoad, /* OpNum */ 2, TargetLocation,
                        /* Addend */ 0, Ctx, ELF::R_AARCH64_LD64_GOT_LO12_NC);

  // Step 3: branch through x17. The register operand is added before the
  // tail-call marker, in the same order as the sequence is described above.
  MCInst Branch;
  if (TailCall)
    Branch.setOpcode(AArch64::BR);
  else
    Branch.setOpcode(AArch64::BLR);
  Branch.addOperand(MCOperand::createReg(AArch64::X17));
  if (TailCall)
    setTailCall(Branch);

  InstructionListType Sequence;
  Sequence.emplace_back(PageAddr);
  Sequence.emplace_back(TargetLoad);
  Sequence.emplace_back(Branch);
  return Sequence;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function looks good to me now.

bool lowerTailCall(MCInst &Inst) override {
removeAnnotation(Inst, MCPlus::MCAnnotation::kTailCall);
if (getConditionalTailCall(Inst))
Expand Down
16 changes: 11 additions & 5 deletions bolt/lib/Target/X86/X86MCPlusBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1639,11 +1639,16 @@ class X86MCPlusBuilder : public MCPlusBuilder {
return true;
}

bool convertCallToIndirectCall(MCInst &Inst, const MCSymbol *TargetLocation,
MCContext *Ctx) override {
assert((Inst.getOpcode() == X86::CALL64pcrel32 ||
(Inst.getOpcode() == X86::JMP_4 && isTailCall(Inst))) &&
InstructionListType createIndirectPltCall(const MCInst &DirectCall,
const MCSymbol *TargetLocation,
MCContext *Ctx) override {
assert((DirectCall.getOpcode() == X86::CALL64pcrel32 ||
(DirectCall.getOpcode() == X86::JMP_4 && isTailCall(DirectCall))) &&
"64-bit direct (tail) call instruction expected");

InstructionListType Code;
// Create a new indirect call by converting the previous direct call.
MCInst Inst = DirectCall;
const auto NewOpcode =
(Inst.getOpcode() == X86::CALL64pcrel32) ? X86::CALL64m : X86::JMP32m;
Inst.setOpcode(NewOpcode);
Expand All @@ -1664,7 +1669,8 @@ class X86MCPlusBuilder : public MCPlusBuilder {
Inst.insert(Inst.begin(),
MCOperand::createReg(X86::RIP)); // BaseReg

return true;
Code.emplace_back(Inst);
return Code;
}

void convertIndirectCallToLoad(MCInst &Inst, MCPhysReg Reg) override {
Expand Down
15 changes: 15 additions & 0 deletions bolt/test/AArch64/plt-call.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Verify that PLTCall optimization works.

RUN: %clang %cflags %p/../Inputs/plt-tailcall.c \
RUN: -o %t -Wl,-q
RUN: llvm-bolt %t -o %t.bolt --plt=all --print-plt --print-only=foo | FileCheck %s

// Call to printf
CHECK: adrp x16, printf@GOT
CHECK: ldr x17, [x16, :lo12:printf@GOT]
CHECK: blr x17 # PLTCall: 1

// Call to puts, that was tail-call optimized
CHECK: adrp x16, puts@GOT
CHECK: ldr x17, [x16, :lo12:puts@GOT]
CHECK: br x17 # TAILCALL # PLTCall: 1
8 changes: 8 additions & 0 deletions bolt/test/Inputs/plt-tailcall.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "stub.h"

// Test input for the PLTCall lit tests: makes one ordinary call and one tail
// call to external functions, so both forms of the rewritten call show up in
// the printed output for `foo`.
int foo(char *c) {
// Ordinary call; the tests expect it to be printed as an indirect call
// through printf@GOT.
printf("");
// musttail asks the compiler to emit this call as a tail call; the tests
// expect an indirect jump through puts@GOT marked TAILCALL.
__attribute__((musttail)) return puts(c);
}

// Entry point: calls foo once and returns its result as the exit status.
int main() { return foo("a"); }
11 changes: 11 additions & 0 deletions bolt/test/X86/plt-call.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Verify that PLTCall optimization works.

RUN: %clang %cflags %p/../Inputs/plt-tailcall.c \
RUN: -o %t -Wl,-q
RUN: llvm-bolt %t -o %t.bolt --plt=all --print-plt --print-only=foo | FileCheck %s

// Call to printf
CHECK: callq *printf@GOT(%rip) # PLTCall: 1

// Call to puts, that was tail-call optimized
CHECK: jmpl *puts@GOT(%rip) # TAILCALL # PLTCall: 1
Loading