Skip to content

Commit bc9e90d

Browse files
committed
swift-inspect: ensure that we eject any injected code
We would previously fail to eject the injected code on a failure. This would prevent a future introspection into the process as well as leave the file open with an incremented retain count in the kernel space which would prevent the file from being deleted. In the future, when the application is able to treat the injected code as a resource, this resource would be temporarily extracted, but would no longer be possible to delete until a reboot (with a registration of the deletion) due to the retained code. Take the opportunity to rename some functions to take advantage of labelled parameters and trailing function syntax. This makes the code a small amount easier to read.
1 parent 19f4b6f commit bc9e90d

File tree

1 file changed

+38
-40
lines changed

1 file changed

+38
-40
lines changed

tools/swift-inspect/Sources/swift-inspect/WindowsRemoteProcess.swift

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,13 @@ internal final class WindowsRemoteProcess: RemoteProcess {
161161
self.context = context
162162

163163
// Locate swiftCore.dll in the target process and load modules.
164-
iterateRemoteModules(
165-
dwProcessId: processId,
166-
closure: { (entry, module) in
167-
// FIXME(compnerd) support static linking at some point
168-
if module == "swiftCore.dll" {
169-
self.hSwiftCore = entry.hModule
170-
}
171-
_ = swift_reflection_addImage(
172-
context, unsafeBitCast(entry.modBaseAddr, to: swift_addr_t.self))
173-
})
164+
modules(of: processId) { (entry, module) in
165+
// FIXME(compnerd) support static linking at some point
166+
if module == "swiftCore.dll" { self.hSwiftCore = entry.hModule }
167+
_ = swift_reflection_addImage(context,
168+
unsafeBitCast(entry.modBaseAddr,
169+
to: swift_addr_t.self))
170+
}
174171
if self.hSwiftCore == HMODULE(bitPattern: -1) {
175172
// FIXME(compnerd) log error
176173
return nil
@@ -321,7 +318,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
321318
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
322319
return
323320
}
324-
let (loadLibraryAddr, freeLibraryAddr) = (remoteAddrs[0], remoteAddrs[1])
321+
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])
325322
let hThread: HANDLE = CreateRemoteThread(
326323
self.process, nil, 0, loadLibraryAddr,
327324
dllPathRemote, 0, nil)
@@ -331,6 +328,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
331328
}
332329
defer { CloseHandle(hThread) }
333330

331+
defer {
332+
// Always perform the code ejection process even if the heap walk fails.
333+
// The module cannot re-execute the heap walk and will leave a retain
334+
// count behind, preventing the module from being unlinked on the file
335+
// system as well as leave code in the inspected process. This will
336+
// eventually be an issue for treating the injected code as a resource
337+
// which is extracted temporarily.
338+
if !eject(module: dllPathRemote, from: dwProcessId, pfnFreeLibrary) {
339+
print("Failed to unload the remote dll")
340+
}
341+
}
342+
334343
// The main heap iteration loop.
335344
outer: while true {
336345
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
@@ -381,14 +390,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
381390
print("LoadLibraryW failed \(threadExitCode)")
382391
return
383392
}
384-
385-
// Unload the dll and deallocate the dll path from the remote process
386-
if !unloadDllAndPathRemote(
387-
dwProcessId: dwProcessId, dllPathRemote: dllPathRemote, freeLibraryAddr: freeLibraryAddr)
388-
{
389-
print("Failed to unload the remote dll")
390-
return
391-
}
392393
}
393394

394395
private func allocateDllPathRemote() -> UnsafeMutableRawPointer? {
@@ -430,26 +431,27 @@ internal final class WindowsRemoteProcess: RemoteProcess {
430431
}
431432
}
432433

433-
private func unloadDllAndPathRemote(
434-
dwProcessId: DWORD, dllPathRemote: UnsafeMutableRawPointer,
435-
freeLibraryAddr: LPTHREAD_START_ROUTINE
436-
) -> Bool {
434+
/// Eject the injected code from the instrumented process.
435+
///
436+
/// Performs the necessary clean up to remove the injected code from the
437+
/// instrumented process once the heap walk is complete.
438+
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
439+
from dwProcessId: DWORD,
440+
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
437441
// Get the dll module handle in the remote process to use it to
438442
// unload it below.
443+
439444
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
440445
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
441446
// So, search for it using the snapshot instead.
442-
guard
443-
let hDllModule = findRemoteModule(
444-
dwProcessId: dwProcessId, moduleName: "SwiftInspectClient.dll")
445-
else {
447+
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
446448
print("Failed to find the client dll")
447449
return false
448450
}
449451
// Unload the dll from the remote process
450452
let hUnloadThread = CreateRemoteThread(
451453
self.process, nil, 0, freeLibraryAddr,
452-
UnsafeMutableRawPointer(hDllModule), 0, nil)
454+
UnsafeMutableRawPointer(hModule), 0, nil)
453455
if hUnloadThread == HANDLE(bitPattern: 0) {
454456
print("CreateRemoteThread for unload failed \(GetLastError())")
455457
return false
@@ -476,7 +478,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
476478
return true
477479
}
478480

479-
private func iterateRemoteModules(dwProcessId: DWORD, closure: (MODULEENTRY32W, String) -> Void) {
481+
private func modules(of dwProcessId: DWORD, _ closure: (MODULEENTRY32W, String) -> Void) {
480482
let hModuleSnapshot: HANDLE =
481483
CreateToolhelp32Snapshot(DWORD(TH32CS_SNAPMODULE), dwProcessId)
482484
if hModuleSnapshot == INVALID_HANDLE_VALUE {
@@ -503,22 +505,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
503505
} while Module32NextW(hModuleSnapshot, &entry)
504506
}
505507

506-
private func findRemoteModule(dwProcessId: DWORD, moduleName: String) -> HMODULE? {
507-
var hDllModule: HMODULE? = nil
508-
iterateRemoteModules(
509-
dwProcessId: dwProcessId,
510-
closure: { (entry, module) in
511-
if module == moduleName {
512-
hDllModule = entry.hModule
513-
}
514-
})
515-
return hDllModule
508+
private func find(module named: String, in dwProcessId: DWORD) -> HMODULE? {
509+
var hModule: HMODULE?
510+
modules(of: dwProcessId) { (entry, module) in
511+
if module == named { hModule = entry.hModule }
512+
}
513+
return hModule
516514
}
517515

518516
private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
519517
-> [LPTHREAD_START_ROUTINE]?
520518
{
521-
guard let hDllModule = findRemoteModule(dwProcessId: dwProcessId, moduleName: moduleName) else {
519+
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
522520
print("Failed to find remote module \(moduleName)")
523521
return nil
524522
}

0 commit comments

Comments
 (0)