Skip to content

Commit 7a6c6c7

Browse files
committed
swift-inspect: clean up some of the code in WindowsRemoteProcess
Add some vertical whitespace to the code ejection process. Alter the logic to clean up the memory allocation first, ignoring the error as the subsequent run will perform a new allocation and this will leak a fixed amount of memory without interrupting the process or use of the tool. No longer check the exit code of the thread as that is always guaranteed to be 0 as the module unloading path does not report any error code in the injected code (DLL). Use the opportunity to do some simple renaming to improve the readability and create an overload for avoiding unnecessary ceremony around use of a shared constant.
1 parent 556c503 commit 7a6c6c7

File tree

1 file changed

+44
-48
lines changed

1 file changed

+44
-48
lines changed

tools/swift-inspect/Sources/swift-inspect/WindowsRemoteProcess.swift

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ import SwiftRemoteMirror
1717
import Foundation
1818
import SwiftInspectClientInterface
1919

20+
internal var WAIT_TIMEOUT_MS: DWORD {
21+
DWORD(SwiftInspectClientInterface.WAIT_TIMEOUT_MS)
22+
}
23+
2024
internal final class WindowsRemoteProcess: RemoteProcess {
2125
public typealias ProcessIdentifier = DWORD
2226
public typealias ProcessHandle = HANDLE
@@ -271,7 +275,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
271275
// memory and event objects
272276
let bufSize = Int(BUF_SIZE)
273277
let sharedMemoryName = "\(SHARED_MEM_NAME_PREFIX)-\(String(dwProcessId))"
274-
let waitTimeoutMs = DWORD(WAIT_TIMEOUT_MS)
275278

276279
// Set up the shared memory
277280
let hMapFile = CreateFileMappingA(
@@ -312,19 +315,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
312315
return
313316
}
314317

315-
// Load the dll and start the heap walk
316-
guard
317-
let remoteAddrs = findRemoteAddresses(
318-
dwProcessId: dwProcessId, moduleName: "KERNEL32.DLL",
319-
symbols: ["LoadLibraryW", "FreeLibrary"])
320-
else {
318+
guard let aEntryPoints = find(module: "KERNEL32.DLL",
319+
symbols: ["LoadLibraryW", "FreeLibrary"],
320+
in: dwProcessId)?.map({
321+
unsafeBitCast($0, to: LPTHREAD_START_ROUTINE.self)
322+
}) else {
321323
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
322324
return
323325
}
324-
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])
326+
327+
let (pfnLoadLibraryW, pfnFreeLibrary) = (aEntryPoints[0], aEntryPoints[1])
325328
let hThread: HANDLE = CreateRemoteThread(
326-
self.process, nil, 0, loadLibraryAddr,
327-
dllPathRemote, 0, nil)
329+
self.process, nil, 0, pfnLoadLibraryW, dllPathRemote, 0, nil)
328330
if hThread == HANDLE(bitPattern: 0) {
329331
print("CreateRemoteThread failed \(GetLastError())")
330332
return
@@ -345,7 +347,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
345347

346348
// The main heap iteration loop.
347349
outer: while true {
348-
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
350+
let wait = WaitForSingleObject(hReadEvent, WAIT_TIMEOUT_MS)
349351
if wait != WAIT_OBJECT_0 {
350352
print("WaitForSingleObject failed \(wait)")
351353
return
@@ -381,7 +383,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
381383
}
382384
}
383385

384-
let wait = WaitForSingleObject(hThread, waitTimeoutMs)
386+
let wait = WaitForSingleObject(hThread, WAIT_TIMEOUT_MS)
385387
if wait != WAIT_OBJECT_0 {
386388
print("WaitForSingleObject on LoadLibrary failed \(wait)")
387389
return
@@ -438,43 +440,44 @@ internal final class WindowsRemoteProcess: RemoteProcess {
438440
///
439441
/// Performs the necessary clean up to remove the injected code from the
440442
/// instrumented process once the heap walk is complete.
441-
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
442-
from dwProcessId: DWORD,
443-
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
444-
// Get the dll module handle in the remote process to use it to
445-
// unload it below.
446-
447-
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
448-
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
449-
// So, search for it using the snapshot instead.
450-
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
443+
private func eject(module pwszModule: UnsafeMutableRawPointer,
444+
from process: DWORD,
445+
_ pfnFunction: LPTHREAD_START_ROUTINE) -> Bool {
446+
// Deallocate the dll path string in the remote process
447+
if !VirtualFreeEx(self.process, pwszModule, 0, DWORD(MEM_RELEASE)) {
448+
print("VirtualFreeEx failed: \(GetLastError())")
449+
}
450+
451+
// Get the dll module handle in the remote process to use it to unload it
452+
// below. `GetExitCodeThread` returns a `DWORD` (32-bit) but the `HMODULE`
453+
// pointer-sized and may be truncated, so, search for it using the snapshot
454+
// instead.
455+
guard let hModule = find(module: "SwiftInspectClient.dll", in: process) else {
451456
print("Failed to find the client dll")
452457
return false
453458
}
459+
454460
// Unload the dll from the remote process
455-
let hUnloadThread = CreateRemoteThread(
456-
self.process, nil, 0, freeLibraryAddr,
457-
UnsafeMutableRawPointer(hModule), 0, nil)
458-
if hUnloadThread == HANDLE(bitPattern: 0) {
461+
guard let hThread = CreateRemoteThread(self.process, nil, 0, pfnFunction,
462+
hModule, 0, nil) else {
459463
print("CreateRemoteThread for unload failed \(GetLastError())")
460464
return false
461465
}
462-
defer { CloseHandle(hUnloadThread) }
463-
let unload_wait = WaitForSingleObject(hUnloadThread, DWORD(WAIT_TIMEOUT_MS))
464-
if unload_wait != WAIT_OBJECT_0 {
465-
print("WaitForSingleObject on FreeLibrary failed \(unload_wait)")
466+
defer { CloseHandle(hThread) }
467+
468+
guard WaitForSingleObject(hThread, WAIT_TIMEOUT_MS) == WAIT_OBJECT_0 else {
469+
print("WaitForSingleObject on FreeLibrary failed \(GetLastError())")
466470
return false
467471
}
468-
var unloadExitCode: DWORD = 0
469-
GetExitCodeThread(hUnloadThread, &unloadExitCode)
470-
if unloadExitCode == 0 {
471-
print("FreeLibrary failed")
472+
473+
var dwExitCode: DWORD = 1
474+
guard GetExitCodeThread(hThread, &dwExitCode) else {
475+
print("GetExitCodeThread for unload failed \(GetLastError())")
472476
return false
473477
}
474478

475-
// Deallocate the dll path string in the remote process
476-
if !VirtualFreeEx(self.process, dllPathRemote, 0, DWORD(MEM_RELEASE)) {
477-
print("VirtualFreeEx failed GLE=\(GetLastError())")
479+
guard dwExitCode == 0 else {
480+
print("FreeLibrary failed \(dwExitCode)")
478481
return false
479482
}
480483

@@ -516,19 +519,12 @@ internal final class WindowsRemoteProcess: RemoteProcess {
516519
return hModule
517520
}
518521

519-
private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
520-
-> [LPTHREAD_START_ROUTINE]?
521-
{
522-
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
523-
print("Failed to find remote module \(moduleName)")
522+
private func find(module: String, symbols: [String], in process: DWORD) -> [FARPROC]? {
523+
guard let hModule = find(module: module, in: process) else {
524+
print("Failed to find remote module \(module)")
524525
return nil
525526
}
526-
var addresses: [LPTHREAD_START_ROUTINE] = []
527-
for sym in symbols {
528-
addresses.append(
529-
unsafeBitCast(GetProcAddress(hDllModule, sym), to: LPTHREAD_START_ROUTINE.self))
530-
}
531-
return addresses
527+
return symbols.map { GetProcAddress(hModule, $0) }
532528
}
533529

534530
private func createEventPair(_ dwProcessId: DWORD) -> (HANDLE, HANDLE)? {

0 commit comments

Comments
 (0)