Skip to content

Commit b12449f

Browse files
authored
[CodeGen] Refactor and document ThunkInserter (#97468)
In preparation for supporting BLRA* instructions in SLS Hardening on AArch64, refactor the ThunkInserter class.

The main intention of this commit is to document the way to merge the BLR-rewriting logic of the AArch64SLSHardening pass into the SLSBLRThunkInserter class. This makes it possible to only call createThunkFunction for the thunks that are actually referenced. Ultimately, it will prevent SLSBLRThunkInserter from unconditionally generating about 1800 thunk functions corresponding to every possible combination of operands passed to BLRAA or BLRAB instructions.

This particular commit does not affect the generated machine code and consists of the following changes:
* document the existing behavior of the ThunkInserter class
* introduce the ThunkInserterPass template class to get rid of mostly identical boilerplate code in the ARM, AArch64 and X86 implementations
* move the InsertedThunks parameter from the `mayUseThunk` method to the `insertThunks` method
1 parent 0035c2e commit b12449f

File tree

4 files changed

+159
-142
lines changed

4 files changed

+159
-142
lines changed

llvm/include/llvm/CodeGen/IndirectThunks.h

Lines changed: 117 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===---- IndirectThunks.h - Indirect Thunk Base Class ----------*- C++ -*-===//
1+
//===---- IndirectThunks.h - Indirect thunk insertion helpers ---*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,34 +7,105 @@
77
//===----------------------------------------------------------------------===//
88
///
99
/// \file
10-
/// Contains a base class for Passes that inject an MI thunk.
10+
/// Contains a base ThunkInserter class that simplifies injection of MI thunks
11+
/// as well as a default implementation of MachineFunctionPass wrapping
12+
/// several `ThunkInserter`s for targets to extend.
1113
///
1214
//===----------------------------------------------------------------------===//
1315

1416
#ifndef LLVM_CODEGEN_INDIRECTTHUNKS_H
1517
#define LLVM_CODEGEN_INDIRECTTHUNKS_H
1618

1719
#include "llvm/CodeGen/MachineFunction.h"
20+
#include "llvm/CodeGen/MachineFunctionPass.h"
1821
#include "llvm/CodeGen/MachineModuleInfo.h"
1922
#include "llvm/IR/IRBuilder.h"
2023
#include "llvm/IR/Module.h"
2124

2225
namespace llvm {
2326

27+
/// This class assists in inserting MI thunk functions into the module and
28+
/// rewriting the existing machine functions to call these thunks.
29+
///
30+
/// One of the common cases is implementing security mitigations that involve
31+
/// replacing some machine code patterns with calls to special thunk functions.
32+
///
33+
/// Inserting a module pass late in the codegen pipeline may increase memory
34+
/// usage, as it serializes the transformations and forces preceding passes to
35+
/// produce machine code for all functions before running the module pass.
36+
/// For that reason, ThunkInserter can be driven by a MachineFunctionPass by
37+
/// passing one MachineFunction at a time to its `run(MMI, MF)` method.
38+
/// Then, the derived class should
39+
/// * call createThunkFunction from its insertThunks method exactly once for
40+
/// each of the thunk functions to be inserted
41+
/// * populate the thunk in its populateThunk method
42+
///
43+
/// Note that if some other pass is responsible for rewriting the functions,
44+
/// the insertThunks method may simply create all possible thunks at once,
45+
/// probably postponed until the first occurrence of a possibly affected MF.
46+
///
47+
/// Alternatively, insertThunks method can rewrite MF by itself and only insert
48+
/// the thunks being called. In that case InsertedThunks variable can be used
49+
/// to track which thunks were already inserted.
50+
///
51+
/// In any case, the thunk function has to be inserted on behalf of some other
52+
/// function and then populated on its own "iteration" later - this is because
53+
/// MachineFunctionPass will see the newly created functions, but they first
54+
/// have to go through the preceding passes from the same pass manager,
55+
/// possibly even through the instruction selector.
56+
//
57+
// FIXME Maybe implement a documented and less surprising way of modifying
58+
// the module from a MachineFunctionPass that is restricted to inserting
59+
// completely new functions to the module.
2460
template <typename Derived, typename InsertedThunksTy = bool>
2561
class ThunkInserter {
2662
Derived &getDerived() { return *static_cast<Derived *>(this); }
2763

28-
protected:
2964
// A variable used to track whether (and possibly which) thunks have been
3065
// inserted so far. InsertedThunksTy is usually a bool, but can be other types
3166
// to represent more than one type of thunk. Requires an |= operator to
3267
// accumulate results.
3368
InsertedThunksTy InsertedThunks;
34-
void doInitialization(Module &M) {}
69+
70+
protected:
71+
// Interface for subclasses to use.
72+
73+
/// Create an empty thunk function.
74+
///
75+
/// The new function will eventually be passed to populateThunk. If multiple
76+
/// thunks are created, populateThunk can distinguish them by their names.
3577
void createThunkFunction(MachineModuleInfo &MMI, StringRef Name,
3678
bool Comdat = true, StringRef TargetAttrs = "");
3779

80+
protected:
81+
// Interface for subclasses to implement.
82+
//
83+
// Note: all functions are non-virtual and are called via getDerived().
84+
// Note: only doInitialization() has an implementation.
85+
86+
/// Initializes thunk inserter.
87+
void doInitialization(Module &M) {}
88+
89+
/// Returns the common prefix for the names of thunk functions.
90+
const char *getThunkPrefix(); // undefined
91+
92+
/// Checks if MF may use thunks (true - maybe, false - definitely not).
93+
bool mayUseThunk(const MachineFunction &MF); // undefined
94+
95+
/// Rewrites the function if necessary, returns the set of thunks added.
96+
InsertedThunksTy insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
97+
InsertedThunksTy ExistingThunks); // undefined
98+
99+
/// Populate the thunk function with instructions.
100+
///
101+
/// If multiple thunks are created, the content that must be inserted in the
102+
/// thunk function body should be derived from the MF's name.
103+
///
104+
/// Depending on the preceding passes in the pass manager, by the time
105+
/// populateThunk is called, MF may have a few target-specific instructions
106+
/// (such as a single MBB containing the return instruction).
107+
void populateThunk(MachineFunction &MF); // undefined
108+
38109
public:
39110
void init(Module &M) {
40111
InsertedThunks = InsertedThunksTy{};
@@ -53,7 +124,7 @@ void ThunkInserter<Derived, InsertedThunksTy>::createThunkFunction(
53124

54125
Module &M = const_cast<Module &>(*MMI.getModule());
55126
LLVMContext &Ctx = M.getContext();
56-
auto Type = FunctionType::get(Type::getVoidTy(Ctx), false);
127+
auto *Type = FunctionType::get(Type::getVoidTy(Ctx), false);
57128
Function *F = Function::Create(Type,
58129
Comdat ? GlobalValue::LinkOnceODRLinkage
59130
: GlobalValue::InternalLinkage,
@@ -95,19 +166,15 @@ bool ThunkInserter<Derived, InsertedThunksTy>::run(MachineModuleInfo &MMI,
95166
MachineFunction &MF) {
96167
// If MF is not a thunk, check to see if we need to insert a thunk.
97168
if (!MF.getName().starts_with(getDerived().getThunkPrefix())) {
98-
// Only add a thunk if one of the functions has the corresponding feature
99-
// enabled in its subtarget, and doesn't enable external thunks. The target
100-
// can use InsertedThunks to detect whether relevant thunks have already
101-
// been inserted.
102-
// FIXME: Conditionalize on indirect calls so we don't emit a thunk when
103-
// nothing will end up calling it.
104-
// FIXME: It's a little silly to look at every function just to enumerate
105-
// the subtargets, but eventually we'll want to look at them for indirect
106-
// calls, so maybe this is OK.
107-
if (!getDerived().mayUseThunk(MF, InsertedThunks))
169+
// Only add thunks if one of the functions may use them.
170+
if (!getDerived().mayUseThunk(MF))
108171
return false;
109172

110-
InsertedThunks |= getDerived().insertThunks(MMI, MF);
173+
// The target can use InsertedThunks to detect whether relevant thunks
174+
// have already been inserted.
175+
// FIXME: Provide a way for insertThunks to notify us whether it changed
176+
// the MF, instead of conservatively assuming it did.
177+
InsertedThunks |= getDerived().insertThunks(MMI, MF, InsertedThunks);
111178
return true;
112179
}
113180

@@ -116,6 +183,40 @@ bool ThunkInserter<Derived, InsertedThunksTy>::run(MachineModuleInfo &MMI,
116183
return true;
117184
}
118185

186+
/// Basic implementation of MachineFunctionPass wrapping one or more
187+
/// `ThunkInserter`s passed as type parameters.
188+
template <typename... Inserters>
189+
class ThunkInserterPass : public MachineFunctionPass {
190+
protected:
191+
std::tuple<Inserters...> TIs;
192+
193+
ThunkInserterPass(char &ID) : MachineFunctionPass(ID) {}
194+
195+
public:
196+
bool doInitialization(Module &M) override {
197+
initTIs(M, TIs);
198+
return false;
199+
}
200+
201+
bool runOnMachineFunction(MachineFunction &MF) override {
202+
auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
203+
return runTIs(MMI, MF, TIs);
204+
}
205+
206+
private:
207+
template <typename... ThunkInserterT>
208+
static void initTIs(Module &M,
209+
std::tuple<ThunkInserterT...> &ThunkInserters) {
210+
(..., std::get<ThunkInserterT>(ThunkInserters).init(M));
211+
}
212+
213+
template <typename... ThunkInserterT>
214+
static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
215+
std::tuple<ThunkInserterT...> &ThunkInserters) {
216+
return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
217+
}
218+
};
219+
119220
} // namespace llvm
120221

121222
#endif

llvm/lib/Target/AArch64/AArch64SLSHardening.cpp

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,12 @@ static const struct ThunkNameAndReg {
183183
namespace {
184184
struct SLSBLRThunkInserter : ThunkInserter<SLSBLRThunkInserter> {
185185
const char *getThunkPrefix() { return SLSBLRNamePrefix; }
186-
bool mayUseThunk(const MachineFunction &MF, bool InsertedThunks) {
187-
if (InsertedThunks)
188-
return false;
186+
bool mayUseThunk(const MachineFunction &MF) {
189187
ComdatThunks &= !MF.getSubtarget<AArch64Subtarget>().hardenSlsNoComdat();
190-
// FIXME: This could also check if there are any BLRs in the function
191-
// to more accurately reflect if a thunk will be needed.
192188
return MF.getSubtarget<AArch64Subtarget>().hardenSlsBlr();
193189
}
194-
bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
190+
bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
191+
bool ExistingThunks);
195192
void populateThunk(MachineFunction &MF);
196193

197194
private:
@@ -200,7 +197,10 @@ struct SLSBLRThunkInserter : ThunkInserter<SLSBLRThunkInserter> {
200197
} // namespace
201198

202199
bool SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
203-
MachineFunction &MF) {
200+
MachineFunction &MF,
201+
bool ExistingThunks) {
202+
if (ExistingThunks)
203+
return false;
204204
// FIXME: It probably would be possible to filter which thunks to produce
205205
// based on which registers are actually used in BLR instructions in this
206206
// function. But would that be a worthwhile optimization?
@@ -210,6 +210,8 @@ bool SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
210210
}
211211

212212
void SLSBLRThunkInserter::populateThunk(MachineFunction &MF) {
213+
assert(MF.getFunction().hasComdat() == ComdatThunks &&
214+
"ComdatThunks value changed since MF creation");
213215
// FIXME: How to better communicate Register number, rather than through
214216
// name and lookup table?
215217
assert(MF.getName().starts_with(getThunkPrefix()));
@@ -411,30 +413,13 @@ FunctionPass *llvm::createAArch64SLSHardeningPass() {
411413
}
412414

413415
namespace {
414-
class AArch64IndirectThunks : public MachineFunctionPass {
416+
class AArch64IndirectThunks : public ThunkInserterPass<SLSBLRThunkInserter> {
415417
public:
416418
static char ID;
417419

418-
AArch64IndirectThunks() : MachineFunctionPass(ID) {}
420+
AArch64IndirectThunks() : ThunkInserterPass(ID) {}
419421

420422
StringRef getPassName() const override { return "AArch64 Indirect Thunks"; }
421-
422-
bool doInitialization(Module &M) override;
423-
bool runOnMachineFunction(MachineFunction &MF) override;
424-
425-
private:
426-
std::tuple<SLSBLRThunkInserter> TIs;
427-
428-
template <typename... ThunkInserterT>
429-
static void initTIs(Module &M,
430-
std::tuple<ThunkInserterT...> &ThunkInserters) {
431-
(..., std::get<ThunkInserterT>(ThunkInserters).init(M));
432-
}
433-
template <typename... ThunkInserterT>
434-
static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
435-
std::tuple<ThunkInserterT...> &ThunkInserters) {
436-
return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
437-
}
438423
};
439424

440425
} // end anonymous namespace
@@ -444,14 +429,3 @@ char AArch64IndirectThunks::ID = 0;
444429
FunctionPass *llvm::createAArch64IndirectThunks() {
445430
return new AArch64IndirectThunks();
446431
}
447-
448-
bool AArch64IndirectThunks::doInitialization(Module &M) {
449-
initTIs(M, TIs);
450-
return false;
451-
}
452-
453-
bool AArch64IndirectThunks::runOnMachineFunction(MachineFunction &MF) {
454-
LLVM_DEBUG(dbgs() << getPassName() << '\n');
455-
auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
456-
return runTIs(MMI, MF, TIs);
457-
}

llvm/lib/Target/ARM/ARMSLSHardening.cpp

Lines changed: 16 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ static const struct ThunkNameRegMode {
163163

164164
// An enum for tracking whether Arm and Thumb thunks have been inserted into the
165165
// current module so far.
166-
enum ArmInsertedThunks { ArmThunk = 1, ThumbThunk = 2 };
166+
enum ArmInsertedThunks { NoThunk = 0, ArmThunk = 1, ThumbThunk = 2 };
167167

168168
inline ArmInsertedThunks &operator|=(ArmInsertedThunks &X,
169169
ArmInsertedThunks Y) {
@@ -174,28 +174,27 @@ namespace {
174174
struct SLSBLRThunkInserter
175175
: ThunkInserter<SLSBLRThunkInserter, ArmInsertedThunks> {
176176
const char *getThunkPrefix() { return SLSBLRNamePrefix; }
177-
bool mayUseThunk(const MachineFunction &MF,
178-
ArmInsertedThunks InsertedThunks) {
179-
if ((InsertedThunks & ArmThunk &&
180-
!MF.getSubtarget<ARMSubtarget>().isThumb()) ||
181-
(InsertedThunks & ThumbThunk &&
182-
MF.getSubtarget<ARMSubtarget>().isThumb()))
183-
return false;
177+
bool mayUseThunk(const MachineFunction &MF) {
184178
ComdatThunks &= !MF.getSubtarget<ARMSubtarget>().hardenSlsNoComdat();
185-
// FIXME: This could also check if there are any indirect calls in the
186-
// function to more accurately reflect if a thunk will be needed.
187179
return MF.getSubtarget<ARMSubtarget>().hardenSlsBlr();
188180
}
189-
ArmInsertedThunks insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
181+
ArmInsertedThunks insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
182+
ArmInsertedThunks InsertedThunks);
190183
void populateThunk(MachineFunction &MF);
191184

192185
private:
193186
bool ComdatThunks = true;
194187
};
195188
} // namespace
196189

197-
ArmInsertedThunks SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
198-
MachineFunction &MF) {
190+
ArmInsertedThunks
191+
SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
192+
ArmInsertedThunks InsertedThunks) {
193+
if ((InsertedThunks & ArmThunk &&
194+
!MF.getSubtarget<ARMSubtarget>().isThumb()) ||
195+
(InsertedThunks & ThumbThunk &&
196+
MF.getSubtarget<ARMSubtarget>().isThumb()))
197+
return NoThunk;
199198
// FIXME: It probably would be possible to filter which thunks to produce
200199
// based on which registers are actually used in indirect calls in this
201200
// function. But would that be a worthwhile optimization?
@@ -208,6 +207,8 @@ ArmInsertedThunks SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
208207
}
209208

210209
void SLSBLRThunkInserter::populateThunk(MachineFunction &MF) {
210+
assert(MF.getFunction().hasComdat() == ComdatThunks &&
211+
"ComdatThunks value changed since MF creation");
211212
// FIXME: How to better communicate Register number, rather than through
212213
// name and lookup table?
213214
assert(MF.getName().starts_with(getThunkPrefix()));
@@ -384,53 +385,18 @@ FunctionPass *llvm::createARMSLSHardeningPass() {
384385
}
385386

386387
namespace {
387-
class ARMIndirectThunks : public MachineFunctionPass {
388+
class ARMIndirectThunks : public ThunkInserterPass<SLSBLRThunkInserter> {
388389
public:
389390
static char ID;
390391

391-
ARMIndirectThunks() : MachineFunctionPass(ID) {}
392+
ARMIndirectThunks() : ThunkInserterPass(ID) {}
392393

393394
StringRef getPassName() const override { return "ARM Indirect Thunks"; }
394-
395-
bool doInitialization(Module &M) override;
396-
bool runOnMachineFunction(MachineFunction &MF) override;
397-
398-
void getAnalysisUsage(AnalysisUsage &AU) const override {
399-
MachineFunctionPass::getAnalysisUsage(AU);
400-
AU.addRequired<MachineModuleInfoWrapperPass>();
401-
AU.addPreserved<MachineModuleInfoWrapperPass>();
402-
}
403-
404-
private:
405-
std::tuple<SLSBLRThunkInserter> TIs;
406-
407-
template <typename... ThunkInserterT>
408-
static void initTIs(Module &M,
409-
std::tuple<ThunkInserterT...> &ThunkInserters) {
410-
(..., std::get<ThunkInserterT>(ThunkInserters).init(M));
411-
}
412-
template <typename... ThunkInserterT>
413-
static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
414-
std::tuple<ThunkInserterT...> &ThunkInserters) {
415-
return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
416-
}
417395
};
418-
419396
} // end anonymous namespace
420397

421398
char ARMIndirectThunks::ID = 0;
422399

423400
FunctionPass *llvm::createARMIndirectThunks() {
424401
return new ARMIndirectThunks();
425402
}
426-
427-
bool ARMIndirectThunks::doInitialization(Module &M) {
428-
initTIs(M, TIs);
429-
return false;
430-
}
431-
432-
bool ARMIndirectThunks::runOnMachineFunction(MachineFunction &MF) {
433-
LLVM_DEBUG(dbgs() << getPassName() << '\n');
434-
auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
435-
return runTIs(MMI, MF, TIs);
436-
}

0 commit comments

Comments
 (0)