Skip to content

Commit 18fa460

Browse files
authored
Merge pull request #70103 from compnerd/remote
swift-inspect: clean up some of the code in `WindowsRemoteProcess`
2 parents 4c754ae + 7a6c6c7 commit 18fa460

File tree

1 file changed

+44
-48
lines changed

1 file changed

+44
-48
lines changed

tools/swift-inspect/Sources/swift-inspect/WindowsRemoteProcess.swift

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ import SwiftRemoteMirror
1717
import Foundation
1818
import SwiftInspectClientInterface
1919

20+
internal var WAIT_TIMEOUT_MS: DWORD {
21+
DWORD(SwiftInspectClientInterface.WAIT_TIMEOUT_MS)
22+
}
23+
2024
internal final class WindowsRemoteProcess: RemoteProcess {
2125
public typealias ProcessIdentifier = DWORD
2226
public typealias ProcessHandle = HANDLE
@@ -271,7 +275,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
271275
// memory and event objects
272276
let bufSize = Int(BUF_SIZE)
273277
let sharedMemoryName = "\(SHARED_MEM_NAME_PREFIX)-\(String(dwProcessId))"
274-
let waitTimeoutMs = DWORD(WAIT_TIMEOUT_MS)
275278

276279
// Set up the shared memory
277280
let hMapFile = CreateFileMappingA(
@@ -312,19 +315,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
312315
return
313316
}
314317

315-
// Load the dll and start the heap walk
316-
guard
317-
let remoteAddrs = findRemoteAddresses(
318-
dwProcessId: dwProcessId, moduleName: "KERNEL32.DLL",
319-
symbols: ["LoadLibraryW", "FreeLibrary"])
320-
else {
318+
guard let aEntryPoints = find(module: "KERNEL32.DLL",
319+
symbols: ["LoadLibraryW", "FreeLibrary"],
320+
in: dwProcessId)?.map({
321+
unsafeBitCast($0, to: LPTHREAD_START_ROUTINE.self)
322+
}) else {
321323
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
322324
return
323325
}
324-
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])
326+
327+
let (pfnLoadLibraryW, pfnFreeLibrary) = (aEntryPoints[0], aEntryPoints[1])
325328
let hThread: HANDLE = CreateRemoteThread(
326-
self.process, nil, 0, loadLibraryAddr,
327-
dllPathRemote, 0, nil)
329+
self.process, nil, 0, pfnLoadLibraryW, dllPathRemote, 0, nil)
328330
if hThread == HANDLE(bitPattern: 0) {
329331
print("CreateRemoteThread failed \(GetLastError())")
330332
return
@@ -345,7 +347,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
345347

346348
// The main heap iteration loop.
347349
outer: while true {
348-
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
350+
let wait = WaitForSingleObject(hReadEvent, WAIT_TIMEOUT_MS)
349351
if wait != WAIT_OBJECT_0 {
350352
print("WaitForSingleObject failed \(wait)")
351353
return
@@ -381,7 +383,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
381383
}
382384
}
383385

384-
let wait = WaitForSingleObject(hThread, waitTimeoutMs)
386+
let wait = WaitForSingleObject(hThread, WAIT_TIMEOUT_MS)
385387
if wait != WAIT_OBJECT_0 {
386388
print("WaitForSingleObject on LoadLibrary failed \(wait)")
387389
return
@@ -438,43 +440,44 @@ internal final class WindowsRemoteProcess: RemoteProcess {
438440
///
439441
/// Performs the necessary clean up to remove the injected code from the
440442
/// instrumented process once the heap walk is complete.
441-
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
442-
from dwProcessId: DWORD,
443-
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
444-
// Get the dll module handle in the remote process to use it to
445-
// unload it below.
446-
447-
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
448-
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
449-
// So, search for it using the snapshot instead.
450-
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
443+
private func eject(module pwszModule: UnsafeMutableRawPointer,
444+
from process: DWORD,
445+
_ pfnFunction: LPTHREAD_START_ROUTINE) -> Bool {
446+
// Deallocate the dll path string in the remote process
447+
if !VirtualFreeEx(self.process, pwszModule, 0, DWORD(MEM_RELEASE)) {
448+
print("VirtualFreeEx failed: \(GetLastError())")
449+
}
450+
451+
// Get the dll module handle in the remote process to use it to unload it
452+
// below. `GetExitCodeThread` returns a `DWORD` (32-bit) but the `HMODULE`
453+
// pointer-sized and may be truncated, so, search for it using the snapshot
454+
// instead.
455+
guard let hModule = find(module: "SwiftInspectClient.dll", in: process) else {
451456
print("Failed to find the client dll")
452457
return false
453458
}
459+
454460
// Unload the dll from the remote process
455-
let hUnloadThread = CreateRemoteThread(
456-
self.process, nil, 0, freeLibraryAddr,
457-
UnsafeMutableRawPointer(hModule), 0, nil)
458-
if hUnloadThread == HANDLE(bitPattern: 0) {
461+
guard let hThread = CreateRemoteThread(self.process, nil, 0, pfnFunction,
462+
hModule, 0, nil) else {
459463
print("CreateRemoteThread for unload failed \(GetLastError())")
460464
return false
461465
}
462-
defer { CloseHandle(hUnloadThread) }
463-
let unload_wait = WaitForSingleObject(hUnloadThread, DWORD(WAIT_TIMEOUT_MS))
464-
if unload_wait != WAIT_OBJECT_0 {
465-
print("WaitForSingleObject on FreeLibrary failed \(unload_wait)")
466+
defer { CloseHandle(hThread) }
467+
468+
guard WaitForSingleObject(hThread, WAIT_TIMEOUT_MS) == WAIT_OBJECT_0 else {
469+
print("WaitForSingleObject on FreeLibrary failed \(GetLastError())")
466470
return false
467471
}
468-
var unloadExitCode: DWORD = 0
469-
GetExitCodeThread(hUnloadThread, &unloadExitCode)
470-
if unloadExitCode == 0 {
471-
print("FreeLibrary failed")
472+
473+
var dwExitCode: DWORD = 1
474+
guard GetExitCodeThread(hThread, &dwExitCode) else {
475+
print("GetExitCodeThread for unload failed \(GetLastError())")
472476
return false
473477
}
474478

475-
// Deallocate the dll path string in the remote process
476-
if !VirtualFreeEx(self.process, dllPathRemote, 0, DWORD(MEM_RELEASE)) {
477-
print("VirtualFreeEx failed GLE=\(GetLastError())")
479+
guard dwExitCode == 0 else {
480+
print("FreeLibrary failed \(dwExitCode)")
478481
return false
479482
}
480483

@@ -516,19 +519,12 @@ internal final class WindowsRemoteProcess: RemoteProcess {
516519
return hModule
517520
}
518521

519-
private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
520-
-> [LPTHREAD_START_ROUTINE]?
521-
{
522-
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
523-
print("Failed to find remote module \(moduleName)")
522+
private func find(module: String, symbols: [String], in process: DWORD) -> [FARPROC]? {
523+
guard let hModule = find(module: module, in: process) else {
524+
print("Failed to find remote module \(module)")
524525
return nil
525526
}
526-
var addresses: [LPTHREAD_START_ROUTINE] = []
527-
for sym in symbols {
528-
addresses.append(
529-
unsafeBitCast(GetProcAddress(hDllModule, sym), to: LPTHREAD_START_ROUTINE.self))
530-
}
531-
return addresses
527+
return symbols.map { GetProcAddress(hModule, $0) }
532528
}
533529

534530
private func createEventPair(_ dwProcessId: DWORD) -> (HANDLE, HANDLE)? {

0 commit comments

Comments
 (0)