Skip to content

Commit 8c9d12c

Browse files
authored
Merge pull request #67396 from compnerd/ejection
swift-inspect: ensure that we eject any injected code
2 parents 8cb61d1 + bc9e90d commit 8c9d12c

File tree

1 file changed

+38
-40
lines changed

1 file changed

+38
-40
lines changed

tools/swift-inspect/Sources/swift-inspect/WindowsRemoteProcess.swift

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,13 @@ internal final class WindowsRemoteProcess: RemoteProcess {
161161
self.context = context
162162

163163
// Locate swiftCore.dll in the target process and load modules.
164-
iterateRemoteModules(
165-
dwProcessId: processId,
166-
closure: { (entry, module) in
167-
// FIXME(compnerd) support static linking at some point
168-
if module == "swiftCore.dll" {
169-
self.hSwiftCore = entry.hModule
170-
}
171-
_ = swift_reflection_addImage(
172-
context, unsafeBitCast(entry.modBaseAddr, to: swift_addr_t.self))
173-
})
164+
modules(of: processId) { (entry, module) in
165+
// FIXME(compnerd) support static linking at some point
166+
if module == "swiftCore.dll" { self.hSwiftCore = entry.hModule }
167+
_ = swift_reflection_addImage(context,
168+
unsafeBitCast(entry.modBaseAddr,
169+
to: swift_addr_t.self))
170+
}
174171
if self.hSwiftCore == HMODULE(bitPattern: -1) {
175172
// FIXME(compnerd) log error
176173
return nil
@@ -321,7 +318,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
321318
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
322319
return
323320
}
324-
let (loadLibraryAddr, freeLibraryAddr) = (remoteAddrs[0], remoteAddrs[1])
321+
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])
325322
let hThread: HANDLE = CreateRemoteThread(
326323
self.process, nil, 0, loadLibraryAddr,
327324
dllPathRemote, 0, nil)
@@ -331,6 +328,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
331328
}
332329
defer { CloseHandle(hThread) }
333330

331+
defer {
332+
// Always perform the code ejection process even if the heap walk fails.
333+
// The module cannot re-execute the heap walk and will leave a retain
334+
// count behind, preventing the module from being unlinked on the file
335+
// system as well as leave code in the inspected process. This will
336+
// eventually be an issue for treating the injected code as a resource
337+
// which is extracted temporarily.
338+
if !eject(module: dllPathRemote, from: dwProcessId, pfnFreeLibrary) {
339+
print("Failed to unload the remote dll")
340+
}
341+
}
342+
334343
// The main heap iteration loop.
335344
outer: while true {
336345
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
@@ -381,14 +390,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
381390
print("LoadLibraryW failed \(threadExitCode)")
382391
return
383392
}
384-
385-
// Unload the dll and deallocate the dll path from the remote process
386-
if !unloadDllAndPathRemote(
387-
dwProcessId: dwProcessId, dllPathRemote: dllPathRemote, freeLibraryAddr: freeLibraryAddr)
388-
{
389-
print("Failed to unload the remote dll")
390-
return
391-
}
392393
}
393394

394395
private func allocateDllPathRemote() -> UnsafeMutableRawPointer? {
@@ -430,26 +431,27 @@ internal final class WindowsRemoteProcess: RemoteProcess {
430431
}
431432
}
432433

433-
private func unloadDllAndPathRemote(
434-
dwProcessId: DWORD, dllPathRemote: UnsafeMutableRawPointer,
435-
freeLibraryAddr: LPTHREAD_START_ROUTINE
436-
) -> Bool {
434+
/// Eject the injected code from the instrumented process.
435+
///
436+
/// Performs the necessary clean up to remove the injected code from the
437+
/// instrumented process once the heap walk is complete.
438+
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
439+
from dwProcessId: DWORD,
440+
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
437441
// Get the dll module handle in the remote process to use it to
438442
// unload it below.
443+
439444
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
440445
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
441446
// So, search for it using the snapshot instead.
442-
guard
443-
let hDllModule = findRemoteModule(
444-
dwProcessId: dwProcessId, moduleName: "SwiftInspectClient.dll")
445-
else {
447+
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
446448
print("Failed to find the client dll")
447449
return false
448450
}
449451
// Unload the dll from the remote process
450452
let hUnloadThread = CreateRemoteThread(
451453
self.process, nil, 0, freeLibraryAddr,
452-
UnsafeMutableRawPointer(hDllModule), 0, nil)
454+
UnsafeMutableRawPointer(hModule), 0, nil)
453455
if hUnloadThread == HANDLE(bitPattern: 0) {
454456
print("CreateRemoteThread for unload failed \(GetLastError())")
455457
return false
@@ -476,7 +478,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
476478
return true
477479
}
478480

479-
private func iterateRemoteModules(dwProcessId: DWORD, closure: (MODULEENTRY32W, String) -> Void) {
481+
private func modules(of dwProcessId: DWORD, _ closure: (MODULEENTRY32W, String) -> Void) {
480482
let hModuleSnapshot: HANDLE =
481483
CreateToolhelp32Snapshot(DWORD(TH32CS_SNAPMODULE), dwProcessId)
482484
if hModuleSnapshot == INVALID_HANDLE_VALUE {
@@ -503,22 +505,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
503505
} while Module32NextW(hModuleSnapshot, &entry)
504506
}
505507

506-
private func findRemoteModule(dwProcessId: DWORD, moduleName: String) -> HMODULE? {
507-
var hDllModule: HMODULE? = nil
508-
iterateRemoteModules(
509-
dwProcessId: dwProcessId,
510-
closure: { (entry, module) in
511-
if module == moduleName {
512-
hDllModule = entry.hModule
513-
}
514-
})
515-
return hDllModule
508+
private func find(module named: String, in dwProcessId: DWORD) -> HMODULE? {
509+
var hModule: HMODULE?
510+
modules(of: dwProcessId) { (entry, module) in
511+
if module == named { hModule = entry.hModule }
512+
}
513+
return hModule
516514
}
517515

518516
private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
519517
-> [LPTHREAD_START_ROUTINE]?
520518
{
521-
guard let hDllModule = findRemoteModule(dwProcessId: dwProcessId, moduleName: moduleName) else {
519+
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
522520
print("Failed to find remote module \(moduleName)")
523521
return nil
524522
}

0 commit comments

Comments
 (0)