Skip to content

Commit 1d38f27

Browse files
committed
swift-inspect: clean up some of the code in WindowsRemoteProcess
Add some vertical whitespace to the code ejection process. Alter the logic to clean up the memory allocation first, ignoring the error as the subsequent run will perform a new allocation and this will leak a fixed amount of memory without interrupting the process or use of the tool. No longer check the exit code of the thread as that is always guaranteed to be 0 as the module unloading path does not report any error code in the injected code (DLL). Use the opportunity to do some simple renaming to improve the readability and create an overload for avoiding unnecessary ceremony around use of a shared constant.
1 parent d8fc855 commit 1d38f27

File tree

1 file changed

+37
-52
lines changed

1 file changed

+37
-52
lines changed

tools/swift-inspect/Sources/swift-inspect/WindowsRemoteProcess.swift

Lines changed: 37 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ import SwiftRemoteMirror
1717
import Foundation
1818
import SwiftInspectClientInterface
1919

20+
internal var WAIT_TIMEOUT_MS: DWORD {
21+
DWORD(SwiftInspectClientInterface.WAIT_TIMEOUT_MS)
22+
}
23+
2024
internal final class WindowsRemoteProcess: RemoteProcess {
2125
public typealias ProcessIdentifier = DWORD
2226
public typealias ProcessHandle = HANDLE
@@ -268,7 +272,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
268272
// memory and event objects
269273
let bufSize = Int(BUF_SIZE)
270274
let sharedMemoryName = "\(SHARED_MEM_NAME_PREFIX)-\(String(dwProcessId))"
271-
let waitTimeoutMs = DWORD(WAIT_TIMEOUT_MS)
272275

273276
// Set up the shared memory
274277
let hMapFile = CreateFileMappingA(
@@ -309,19 +312,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
309312
return
310313
}
311314

312-
// Load the dll and start the heap walk
313-
guard
314-
let remoteAddrs = findRemoteAddresses(
315-
dwProcessId: dwProcessId, moduleName: "KERNEL32.DLL",
316-
symbols: ["LoadLibraryW", "FreeLibrary"])
317-
else {
315+
guard let aEntryPoints = find(module: "KERNEL32.DLL",
316+
symbols: ["LoadLibraryW", "FreeLibrary"],
317+
in: dwProcessId)?.map({
318+
unsafeBitCast($0, to: LPTHREAD_START_ROUTINE.self)
319+
}) else {
318320
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
319321
return
320322
}
321-
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])
323+
324+
let (pfnLoadLibraryW, pfnFreeLibrary) = (aEntryPoints[0], aEntryPoints[1])
322325
let hThread: HANDLE = CreateRemoteThread(
323-
self.process, nil, 0, loadLibraryAddr,
324-
dllPathRemote, 0, nil)
326+
self.process, nil, 0, pfnLoadLibraryW, dllPathRemote, 0, nil)
325327
if hThread == HANDLE(bitPattern: 0) {
326328
print("CreateRemoteThread failed \(GetLastError())")
327329
return
@@ -342,7 +344,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
342344

343345
// The main heap iteration loop.
344346
outer: while true {
345-
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
347+
let wait = WaitForSingleObject(hReadEvent, WAIT_TIMEOUT_MS)
346348
if wait != WAIT_OBJECT_0 {
347349
print("WaitForSingleObject failed \(wait)")
348350
return
@@ -378,7 +380,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
378380
}
379381
}
380382

381-
let wait = WaitForSingleObject(hThread, waitTimeoutMs)
383+
let wait = WaitForSingleObject(hThread, WAIT_TIMEOUT_MS)
382384
if wait != WAIT_OBJECT_0 {
383385
print("WaitForSingleObject on LoadLibrary failed \(wait)")
384386
return
@@ -435,43 +437,33 @@ internal final class WindowsRemoteProcess: RemoteProcess {
435437
///
436438
/// Performs the necessary clean up to remove the injected code from the
437439
/// instrumented process once the heap walk is complete.
438-
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
439-
from dwProcessId: DWORD,
440-
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
441-
// Get the dll module handle in the remote process to use it to
442-
// unload it below.
443-
444-
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
445-
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
446-
// So, search for it using the snapshot instead.
447-
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
440+
private func eject(module pwszModule: UnsafeMutableRawPointer,
441+
from process: DWORD,
442+
_ pfnFunction: LPTHREAD_START_ROUTINE) -> Bool {
443+
// Deallocate the dll path string in the remote process
444+
if !VirtualFreeEx(self.process, pwszModule, 0, DWORD(MEM_RELEASE)) {
445+
print("VirtualFreeEx failed: \(GetLastError())")
446+
}
447+
448+
// Get the dll module handle in the remote process to use it to unload it
449+
// below. `GetExitCodeThread` returns a `DWORD` (32-bit) but the `HMODULE`
450+
// pointer-sized and may be truncated, so, search for it using the snapshot
451+
// instead.
452+
guard let hModule = find(module: "SwiftInspectClient.dll", in: process) else {
448453
print("Failed to find the client dll")
449454
return false
450455
}
456+
451457
// Unload the dll from the remote process
452-
let hUnloadThread = CreateRemoteThread(
453-
self.process, nil, 0, freeLibraryAddr,
454-
UnsafeMutableRawPointer(hModule), 0, nil)
455-
if hUnloadThread == HANDLE(bitPattern: 0) {
458+
guard let hThread = CreateRemoteThread(self.process, nil, 0, pfnFunction,
459+
hModule, 0, nil) else {
456460
print("CreateRemoteThread for unload failed \(GetLastError())")
457461
return false
458462
}
459-
defer { CloseHandle(hUnloadThread) }
460-
let unload_wait = WaitForSingleObject(hUnloadThread, DWORD(WAIT_TIMEOUT_MS))
461-
if unload_wait != WAIT_OBJECT_0 {
462-
print("WaitForSingleObject on FreeLibrary failed \(unload_wait)")
463-
return false
464-
}
465-
var unloadExitCode: DWORD = 0
466-
GetExitCodeThread(hUnloadThread, &unloadExitCode)
467-
if unloadExitCode == 0 {
468-
print("FreeLibrary failed")
469-
return false
470-
}
463+
defer { CloseHandle(hThread) }
471464

472-
// Deallocate the dll path string in the remote process
473-
if !VirtualFreeEx(self.process, dllPathRemote, 0, DWORD(MEM_RELEASE)) {
474-
print("VirtualFreeEx failed GLE=\(GetLastError())")
465+
guard WaitForSingleObject(hThread, WAIT_TIMEOUT_MS) == WAIT_OBJECT_0 else {
466+
print("WaitForSingleObject on FreeLibrary failed \(GetLastError())")
475467
return false
476468
}
477469

@@ -513,19 +505,12 @@ internal final class WindowsRemoteProcess: RemoteProcess {
513505
return hModule
514506
}
515507

516-
private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
517-
-> [LPTHREAD_START_ROUTINE]?
518-
{
519-
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
520-
print("Failed to find remote module \(moduleName)")
508+
private func find(module: String, symbols: [String], in process: DWORD) -> [FARPROC]? {
509+
guard let hModule = find(module: module, in: process) else {
510+
print("Failed to find remote module \(module)")
521511
return nil
522512
}
523-
var addresses: [LPTHREAD_START_ROUTINE] = []
524-
for sym in symbols {
525-
addresses.append(
526-
unsafeBitCast(GetProcAddress(hDllModule, sym), to: LPTHREAD_START_ROUTINE.self))
527-
}
528-
return addresses
513+
return symbols.map { GetProcAddress(hModule, $0) }
529514
}
530515

531516
private func createEventPair(_ dwProcessId: DWORD) -> (HANDLE, HANDLE)? {

0 commit comments

Comments
 (0)