Skip to content

Commit ddf5bf0

Browse files
committed
Fix review comments by Eli
- Fix LOAD_STACK_GUARD detection logic - Make transform run when security check sequence matched - Preserve dominator tree
1 parent 31e6c78 commit ddf5bf0

File tree

2 files changed

+96
-63
lines changed

2 files changed

+96
-63
lines changed

llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp

Lines changed: 95 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "llvm/CodeGen/LivePhysRegs.h"
16+
#include "llvm/CodeGen/MachineDominators.h"
1617
#include "llvm/CodeGen/MachineFunctionPass.h"
1718
#include "llvm/CodeGen/MachineInstrBuilder.h"
19+
#include "llvm/CodeGen/MachineLoopInfo.h"
1820
#include "llvm/CodeGen/MachineRegisterInfo.h"
1921
#include "llvm/IR/Module.h"
2022

@@ -40,14 +42,14 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
4042

4143
bool runOnMachineFunction(MachineFunction &MF) override;
4244

43-
std::pair<MachineBasicBlock *, MachineInstr *>
44-
getSecurityCheckerBasicBlock(MachineFunction &MF);
45+
void getAnalysisUsage(AnalysisUsage &AU) const override;
4546

46-
MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
47-
MachineInstr *CheckCall);
47+
std::pair<MachineInstr *, MachineInstr *>
48+
findSecurityCheckAndLoadStackGuard(MachineFunction &MF);
4849

49-
void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
50-
MachineInstr *SeqMI[5]);
50+
MachineInstr *cloneLoadStackGuard(MachineFunction &MF, MachineInstr *MI);
51+
52+
bool getGuardCheckSequence(MachineInstr *CheckCall, MachineInstr *SeqMI[5]);
5153

5254
void finishBlock(MachineBasicBlock *MBB);
5355

@@ -64,93 +66,113 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
6466
return new AArch64WinFixupBufferSecurityCheckPass();
6567
}
6668

67-
std::pair<MachineBasicBlock *, MachineInstr *>
68-
AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
69+
void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage(
70+
AnalysisUsage &AU) const {
71+
AU.addUsedIfAvailable<MachineDominatorTreeWrapperPass>();
72+
AU.addPreserved<MachineDominatorTreeWrapperPass>();
73+
AU.addPreserved<MachineLoopInfoWrapperPass>();
74+
MachineFunctionPass::getAnalysisUsage(AU);
75+
}
76+
77+
std::pair<MachineInstr *, MachineInstr *>
78+
AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard(
6979
MachineFunction &MF) {
80+
81+
MachineInstr *SecurityCheckCall = nullptr;
82+
MachineInstr *LoadStackGuard = nullptr;
83+
7084
for (auto &MBB : MF) {
7185
for (auto &MI : MBB) {
86+
if (!LoadStackGuard && MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
87+
LoadStackGuard = &MI;
88+
}
89+
7290
if (MI.isCall() && MI.getNumExplicitOperands() == 1) {
7391
auto MO = MI.getOperand(0);
7492
if (MO.isGlobal()) {
7593
auto Callee = dyn_cast<Function>(MO.getGlobal());
7694
if (Callee && Callee->getName() == "__security_check_cookie") {
77-
return std::make_pair(&MBB, &MI);
95+
SecurityCheckCall = &MI;
7896
}
7997
}
8098
}
99+
100+
// If both are found, return them
101+
if (LoadStackGuard && SecurityCheckCall) {
102+
return std::make_pair(LoadStackGuard, SecurityCheckCall);
103+
}
81104
}
82105
}
106+
83107
return std::make_pair(nullptr, nullptr);
84108
}
85109

86-
MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
87-
MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
88-
// Ensure that we have a valid MachineBasicBlock and CheckCall
89-
if (!CurMBB || !CheckCall)
90-
return nullptr;
110+
MachineInstr *
111+
AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(MachineFunction &MF,
112+
MachineInstr *MI) {
91113

92-
MachineFunction &MF = *CurMBB->getParent();
114+
MachineInstr *ClonedInstr = MF.CloneMachineInstr(MI);
115+
116+
// Get the register class of the original destination register
117+
Register OrigReg = MI->getOperand(0).getReg();
93118
MachineRegisterInfo &MRI = MF.getRegInfo();
119+
const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
94120

95-
// Initialize reverse iterator starting just before CheckCall
96-
MachineBasicBlock::reverse_iterator DIt(CheckCall);
97-
MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
98-
99-
// Reverse iterate from CheckCall to find LOAD_STACK_GUARD
100-
for (; DIt != DEnd; ++DIt) {
101-
MachineInstr &MI = *DIt;
102-
if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
103-
// Clone the LOAD_STACK_GUARD instruction
104-
MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
105-
106-
// Get the register class of the original destination register
107-
Register OrigReg = MI.getOperand(0).getReg();
108-
const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
109-
110-
// Create a new virtual register in the same register class
111-
Register NewReg = MRI.createVirtualRegister(RegClass);
112-
113-
// Update operand 0 (destination) of the cloned instruction
114-
MachineOperand &DestOperand = ClonedInstr->getOperand(0);
115-
if (DestOperand.isReg() && DestOperand.isDef()) {
116-
DestOperand.setReg(NewReg); // Set the new virtual register
117-
}
121+
// Create a new virtual register in the same register class
122+
Register NewReg = MRI.createVirtualRegister(RegClass);
118123

119-
// Return the modified cloned instruction
120-
return ClonedInstr;
121-
}
124+
// Update operand 0 (destination) of the cloned instruction
125+
MachineOperand &DestOperand = ClonedInstr->getOperand(0);
126+
if (DestOperand.isReg() && DestOperand.isDef()) {
127+
DestOperand.setReg(NewReg); // Set the new virtual register
122128
}
123129

124-
// If no LOAD_STACK_GUARD instruction was found, return nullptr
125-
return nullptr;
130+
return ClonedInstr;
126131
}
127132

128-
void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
129-
MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
130-
MachineInstr *SeqMI[5]) {
133+
bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
134+
MachineInstr *CheckCall, MachineInstr *SeqMI[5]) {
135+
136+
MachineBasicBlock *MBB = CheckCall->getParent();
131137

132138
MachineBasicBlock::iterator UIt(CheckCall);
133139
MachineBasicBlock::reverse_iterator DIt(CheckCall);
134140

135141
// Move forward to find the stack adjustment after the call
136-
// to __security_check_cookie
137142
++UIt;
143+
if (UIt == MBB->end() || UIt->getOpcode() != AArch64::ADJCALLSTACKUP) {
144+
return false;
145+
}
138146
SeqMI[4] = &*UIt;
139147

140148
// Assign the BL instruction (call to __security_check_cookie)
141149
SeqMI[3] = CheckCall;
142150

143-
// COPY function slot cookie
151+
// Move backward to find the COPY instruction for the function slot cookie
152+
// argument passing
144153
++DIt;
154+
if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::COPY) {
155+
return false;
156+
}
145157
SeqMI[2] = &*DIt;
146158

147159
// Move backward to find the instruction that loads the security cookie from
148160
// the stack
149161
++DIt;
162+
if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::LDRXui) {
163+
return false;
164+
}
150165
SeqMI[1] = &*DIt;
151166

152-
++DIt; // Find ADJCALLSTACKDOWN
167+
// Move backward to find the stack adjustment before the call
168+
++DIt;
169+
if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::ADJCALLSTACKDOWN) {
170+
return false;
171+
}
153172
SeqMI[0] = &*DIt;
173+
174+
// If all instructions are matched and stored, the sequence is valid
175+
return true;
154176
}
155177

156178
void AArch64WinFixupBufferSecurityCheckPass::finishBlock(
@@ -185,21 +207,23 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
185207
if (!GV)
186208
return Changed;
187209

188-
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
189-
190-
// Check if security check cookie call was installed or not
191-
auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
192-
if (!CheckCall)
210+
// Find LOAD_STACK_GUARD and __security_check_cookie instructions
211+
auto [StackGuard, CheckCall] = findSecurityCheckAndLoadStackGuard(MF);
212+
if (!CheckCall || !StackGuard)
193213
return Changed;
194214

195-
// Get sequence of instruction in CurMBB responsible for calling
215+
// Get sequence of instructions in current basic block responsible for calling
196216
// __security_check_cookie
197217
MachineInstr *SeqMI[5];
198-
getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
218+
if (!getGuardCheckSequence(CheckCall, SeqMI))
219+
return Changed;
220+
221+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
222+
MachineBasicBlock *CurMBB = CheckCall->getParent();
199223

200224
// Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
201225
// instruction with new destination register
202-
MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
226+
MachineInstr *ClonedInstr = cloneLoadStackGuard(MF, StackGuard);
203227
if (!ClonedInstr)
204228
return Changed;
205229

@@ -216,13 +240,14 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
216240
CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
217241

218242
// Create a new virtual register for the CMP instruction result
219-
Register DiscardReg =
220-
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
243+
MachineRegisterInfo &MRI = MF.getRegInfo();
244+
Register DiscardReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
221245

222246
// Emit the CMP instruction to compare stack cookie with global cookie
223247
BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
224-
.addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
225-
.addReg(CookieLoadReg) // First operand: stack cookie
248+
.addReg(DiscardReg,
249+
RegState::Define | RegState::Dead) // Result discarded
250+
.addReg(CookieLoadReg) // First operand: stack cookie
226251
.addReg(GlobalCookieReg); // Second operand: global cookie
227252

228253
// Create FailMBB basic block to call __security_check_cookie
@@ -258,6 +283,15 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
258283
CurMBB->addSuccessor(NewRetMBB);
259284
CurMBB->addSuccessor(FailMBB);
260285

286+
MachineDominatorTreeWrapperPass *WrapperPass =
287+
getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
288+
MachineDominatorTree *MDT =
289+
WrapperPass ? &WrapperPass->getDomTree() : nullptr;
290+
if (MDT) {
291+
MDT->addNewBlock(FailMBB, CurMBB);
292+
MDT->addNewBlock(NewRetMBB, CurMBB);
293+
}
294+
261295
finishFunction(FailMBB, NewRetMBB);
262296

263297
return !Changed;

llvm/test/CodeGen/AArch64/O3-pipeline.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,9 @@
168168
; CHECK-NEXT: Process Implicit Definitions
169169
; CHECK-NEXT: Remove unreachable machine basic blocks
170170
; CHECK-NEXT: Live Variable Analysis
171-
; CHECK-NEXT: MachineDominator Tree Construction
172-
; CHECK-NEXT: Machine Natural Loop Construction
173171
; CHECK-NEXT: Eliminate PHI nodes for register allocation
174172
; CHECK-NEXT: Two-Address instruction pass
173+
; CHECK-NEXT: MachineDominator Tree Construction
175174
; CHECK-NEXT: Slot index numbering
176175
; CHECK-NEXT: Live Interval Analysis
177176
; CHECK-NEXT: Register Coalescer

0 commit comments

Comments
 (0)