13
13
// ===----------------------------------------------------------------------===//
14
14
15
15
#include " llvm/CodeGen/LivePhysRegs.h"
16
+ #include " llvm/CodeGen/MachineDominators.h"
16
17
#include " llvm/CodeGen/MachineFunctionPass.h"
17
18
#include " llvm/CodeGen/MachineInstrBuilder.h"
19
+ #include " llvm/CodeGen/MachineLoopInfo.h"
18
20
#include " llvm/CodeGen/MachineRegisterInfo.h"
19
21
#include " llvm/IR/Module.h"
20
22
@@ -40,14 +42,14 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
40
42
41
43
bool runOnMachineFunction (MachineFunction &MF) override ;
42
44
43
- std::pair<MachineBasicBlock *, MachineInstr *>
44
- getSecurityCheckerBasicBlock (MachineFunction &MF);
45
+ void getAnalysisUsage (AnalysisUsage &AU) const override ;
45
46
46
- MachineInstr *cloneLoadStackGuard (MachineBasicBlock *CurMBB,
47
- MachineInstr *CheckCall );
47
+ std::pair< MachineInstr *, MachineInstr *>
48
+ findSecurityCheckAndLoadStackGuard (MachineFunction &MF );
48
49
49
- void getGuardCheckSequence (MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
50
- MachineInstr *SeqMI[5 ]);
50
+ MachineInstr *cloneLoadStackGuard (MachineFunction &MF, MachineInstr *MI);
51
+
52
+ bool getGuardCheckSequence (MachineInstr *CheckCall, MachineInstr *SeqMI[5 ]);
51
53
52
54
void finishBlock (MachineBasicBlock *MBB);
53
55
@@ -64,93 +66,113 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
64
66
return new AArch64WinFixupBufferSecurityCheckPass ();
65
67
}
66
68
67
- std::pair<MachineBasicBlock *, MachineInstr *>
68
- AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock (
69
+ void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage (
70
+ AnalysisUsage &AU) const {
71
+ AU.addUsedIfAvailable <MachineDominatorTreeWrapperPass>();
72
+ AU.addPreserved <MachineDominatorTreeWrapperPass>();
73
+ AU.addPreserved <MachineLoopInfoWrapperPass>();
74
+ MachineFunctionPass::getAnalysisUsage (AU);
75
+ }
76
+
77
+ std::pair<MachineInstr *, MachineInstr *>
78
+ AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard (
69
79
MachineFunction &MF) {
80
+
81
+ MachineInstr *SecurityCheckCall = nullptr ;
82
+ MachineInstr *LoadStackGuard = nullptr ;
83
+
70
84
for (auto &MBB : MF) {
71
85
for (auto &MI : MBB) {
86
+ if (!LoadStackGuard && MI.getOpcode () == TargetOpcode::LOAD_STACK_GUARD) {
87
+ LoadStackGuard = &MI;
88
+ }
89
+
72
90
if (MI.isCall () && MI.getNumExplicitOperands () == 1 ) {
73
91
auto MO = MI.getOperand (0 );
74
92
if (MO.isGlobal ()) {
75
93
auto Callee = dyn_cast<Function>(MO.getGlobal ());
76
94
if (Callee && Callee->getName () == " __security_check_cookie" ) {
77
- return std::make_pair (&MBB, &MI) ;
95
+ SecurityCheckCall = &MI;
78
96
}
79
97
}
80
98
}
99
+
100
+ // If both are found, return them
101
+ if (LoadStackGuard && SecurityCheckCall) {
102
+ return std::make_pair (LoadStackGuard, SecurityCheckCall);
103
+ }
81
104
}
82
105
}
106
+
83
107
return std::make_pair (nullptr , nullptr );
84
108
}
85
109
86
- MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard (
87
- MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
88
- // Ensure that we have a valid MachineBasicBlock and CheckCall
89
- if (!CurMBB || !CheckCall)
90
- return nullptr ;
110
+ MachineInstr *
111
+ AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard (MachineFunction &MF,
112
+ MachineInstr *MI) {
91
113
92
- MachineFunction &MF = *CurMBB->getParent ();
114
+ MachineInstr *ClonedInstr = MF.CloneMachineInstr (MI);
115
+
116
+ // Get the register class of the original destination register
117
+ Register OrigReg = MI->getOperand (0 ).getReg ();
93
118
MachineRegisterInfo &MRI = MF.getRegInfo ();
119
+ const TargetRegisterClass *RegClass = MRI.getRegClass (OrigReg);
94
120
95
- // Initialize reverse iterator starting just before CheckCall
96
- MachineBasicBlock::reverse_iterator DIt (CheckCall);
97
- MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend ();
98
-
99
- // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
100
- for (; DIt != DEnd; ++DIt) {
101
- MachineInstr &MI = *DIt;
102
- if (MI.getOpcode () == TargetOpcode::LOAD_STACK_GUARD) {
103
- // Clone the LOAD_STACK_GUARD instruction
104
- MachineInstr *ClonedInstr = MF.CloneMachineInstr (&MI);
105
-
106
- // Get the register class of the original destination register
107
- Register OrigReg = MI.getOperand (0 ).getReg ();
108
- const TargetRegisterClass *RegClass = MRI.getRegClass (OrigReg);
109
-
110
- // Create a new virtual register in the same register class
111
- Register NewReg = MRI.createVirtualRegister (RegClass);
112
-
113
- // Update operand 0 (destination) of the cloned instruction
114
- MachineOperand &DestOperand = ClonedInstr->getOperand (0 );
115
- if (DestOperand.isReg () && DestOperand.isDef ()) {
116
- DestOperand.setReg (NewReg); // Set the new virtual register
117
- }
121
+ // Create a new virtual register in the same register class
122
+ Register NewReg = MRI.createVirtualRegister (RegClass);
118
123
119
- // Return the modified cloned instruction
120
- return ClonedInstr;
121
- }
124
+ // Update operand 0 (destination) of the cloned instruction
125
+ MachineOperand &DestOperand = ClonedInstr->getOperand (0 );
126
+ if (DestOperand.isReg () && DestOperand.isDef ()) {
127
+ DestOperand.setReg (NewReg); // Set the new virtual register
122
128
}
123
129
124
- // If no LOAD_STACK_GUARD instruction was found, return nullptr
125
- return nullptr ;
130
+ return ClonedInstr;
126
131
}
127
132
128
- void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence (
129
- MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
130
- MachineInstr *SeqMI[5 ]) {
133
+ bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence (
134
+ MachineInstr *CheckCall, MachineInstr *SeqMI[5 ]) {
135
+
136
+ MachineBasicBlock *MBB = CheckCall->getParent ();
131
137
132
138
MachineBasicBlock::iterator UIt (CheckCall);
133
139
MachineBasicBlock::reverse_iterator DIt (CheckCall);
134
140
135
141
// Move forward to find the stack adjustment after the call
136
- // to __security_check_cookie
137
142
++UIt;
143
+ if (UIt == MBB->end () || UIt->getOpcode () != AArch64::ADJCALLSTACKUP) {
144
+ return false ;
145
+ }
138
146
SeqMI[4 ] = &*UIt;
139
147
140
148
// Assign the BL instruction (call to __security_check_cookie)
141
149
SeqMI[3 ] = CheckCall;
142
150
143
- // COPY function slot cookie
151
+ // Move backward to find the COPY instruction for the function slot cookie
152
+ // argument passing
144
153
++DIt;
154
+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::COPY) {
155
+ return false ;
156
+ }
145
157
SeqMI[2 ] = &*DIt;
146
158
147
159
// Move backward to find the instruction that loads the security cookie from
148
160
// the stack
149
161
++DIt;
162
+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::LDRXui) {
163
+ return false ;
164
+ }
150
165
SeqMI[1 ] = &*DIt;
151
166
152
- ++DIt; // Find ADJCALLSTACKDOWN
167
+ // Move backward to find the stack adjustment before the call
168
+ ++DIt;
169
+ if (DIt == MBB->rend () || DIt->getOpcode () != AArch64::ADJCALLSTACKDOWN) {
170
+ return false ;
171
+ }
153
172
SeqMI[0 ] = &*DIt;
173
+
174
+ // If all instructions are matched and stored, the sequence is valid
175
+ return true ;
154
176
}
155
177
156
178
void AArch64WinFixupBufferSecurityCheckPass::finishBlock (
@@ -185,21 +207,23 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
185
207
if (!GV)
186
208
return Changed;
187
209
188
- const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
189
-
190
- // Check if security check cookie call was installed or not
191
- auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock (MF);
192
- if (!CheckCall)
210
+ // Find LOAD_STACK_GUARD and __security_check_cookie instructions
211
+ auto [StackGuard, CheckCall] = findSecurityCheckAndLoadStackGuard (MF);
212
+ if (!CheckCall || !StackGuard)
193
213
return Changed;
194
214
195
- // Get sequence of instruction in CurMBB responsible for calling
215
+ // Get sequence of instructions in current basic block responsible for calling
196
216
// __security_check_cookie
197
217
MachineInstr *SeqMI[5 ];
198
- getGuardCheckSequence (CurMBB, CheckCall, SeqMI);
218
+ if (!getGuardCheckSequence (CheckCall, SeqMI))
219
+ return Changed;
220
+
221
+ const TargetInstrInfo *TII = MF.getSubtarget ().getInstrInfo ();
222
+ MachineBasicBlock *CurMBB = CheckCall->getParent ();
199
223
200
224
// Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
201
225
// instruction with new destination register
202
- MachineInstr *ClonedInstr = cloneLoadStackGuard (CurMBB, CheckCall );
226
+ MachineInstr *ClonedInstr = cloneLoadStackGuard (MF, StackGuard );
203
227
if (!ClonedInstr)
204
228
return Changed;
205
229
@@ -216,13 +240,14 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
216
240
CurMBB->splice (InsertPt, CurMBB, std::next (InsertPt));
217
241
218
242
// Create a new virtual register for the CMP instruction result
219
- Register DiscardReg =
220
- MF. getRegInfo () .createVirtualRegister (&AArch64::GPR64RegClass);
243
+ MachineRegisterInfo &MRI = MF. getRegInfo ();
244
+ Register DiscardReg = MRI .createVirtualRegister (&AArch64::GPR64RegClass);
221
245
222
246
// Emit the CMP instruction to compare stack cookie with global cookie
223
247
BuildMI (*CurMBB, InsertPt, DebugLoc (), TII->get (AArch64::SUBSXrr))
224
- .addReg (DiscardReg, RegState::Define | RegState::Dead) // Result discarded
225
- .addReg (CookieLoadReg) // First operand: stack cookie
248
+ .addReg (DiscardReg,
249
+ RegState::Define | RegState::Dead) // Result discarded
250
+ .addReg (CookieLoadReg) // First operand: stack cookie
226
251
.addReg (GlobalCookieReg); // Second operand: global cookie
227
252
228
253
// Create FailMBB basic block to call __security_check_cookie
@@ -258,6 +283,15 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
258
283
CurMBB->addSuccessor (NewRetMBB);
259
284
CurMBB->addSuccessor (FailMBB);
260
285
286
+ MachineDominatorTreeWrapperPass *WrapperPass =
287
+ getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
288
+ MachineDominatorTree *MDT =
289
+ WrapperPass ? &WrapperPass->getDomTree () : nullptr ;
290
+ if (MDT) {
291
+ MDT->addNewBlock (FailMBB, CurMBB);
292
+ MDT->addNewBlock (NewRetMBB, CurMBB);
293
+ }
294
+
261
295
finishFunction (FailMBB, NewRetMBB);
262
296
263
297
return !Changed;
0 commit comments