Skip to content

swift-inspect: clean up some of the code in WindowsRemoteProcess #70103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import SwiftRemoteMirror
import Foundation
import SwiftInspectClientInterface

internal var WAIT_TIMEOUT_MS: DWORD {
DWORD(SwiftInspectClientInterface.WAIT_TIMEOUT_MS)
}

internal final class WindowsRemoteProcess: RemoteProcess {
public typealias ProcessIdentifier = DWORD
public typealias ProcessHandle = HANDLE
Expand Down Expand Up @@ -271,7 +275,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
// memory and event objects
let bufSize = Int(BUF_SIZE)
let sharedMemoryName = "\(SHARED_MEM_NAME_PREFIX)-\(String(dwProcessId))"
let waitTimeoutMs = DWORD(WAIT_TIMEOUT_MS)

// Set up the shared memory
let hMapFile = CreateFileMappingA(
Expand Down Expand Up @@ -312,19 +315,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
return
}

// Load the dll and start the heap walk
guard
let remoteAddrs = findRemoteAddresses(
dwProcessId: dwProcessId, moduleName: "KERNEL32.DLL",
symbols: ["LoadLibraryW", "FreeLibrary"])
else {
guard let aEntryPoints = find(module: "KERNEL32.DLL",
symbols: ["LoadLibraryW", "FreeLibrary"],
in: dwProcessId)?.map({
unsafeBitCast($0, to: LPTHREAD_START_ROUTINE.self)
}) else {
print("Failed to find remote LoadLibraryW/FreeLibrary addresses")
return
}
let (loadLibraryAddr, pfnFreeLibrary) = (remoteAddrs[0], remoteAddrs[1])

let (pfnLoadLibraryW, pfnFreeLibrary) = (aEntryPoints[0], aEntryPoints[1])
let hThread: HANDLE = CreateRemoteThread(
self.process, nil, 0, loadLibraryAddr,
dllPathRemote, 0, nil)
self.process, nil, 0, pfnLoadLibraryW, dllPathRemote, 0, nil)
if hThread == HANDLE(bitPattern: 0) {
print("CreateRemoteThread failed \(GetLastError())")
return
Expand All @@ -345,7 +347,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {

// The main heap iteration loop.
outer: while true {
let wait = WaitForSingleObject(hReadEvent, waitTimeoutMs)
let wait = WaitForSingleObject(hReadEvent, WAIT_TIMEOUT_MS)
if wait != WAIT_OBJECT_0 {
print("WaitForSingleObject failed \(wait)")
return
Expand Down Expand Up @@ -381,7 +383,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
}
}

let wait = WaitForSingleObject(hThread, waitTimeoutMs)
let wait = WaitForSingleObject(hThread, WAIT_TIMEOUT_MS)
if wait != WAIT_OBJECT_0 {
print("WaitForSingleObject on LoadLibrary failed \(wait)")
return
Expand Down Expand Up @@ -438,43 +440,44 @@ internal final class WindowsRemoteProcess: RemoteProcess {
///
/// Performs the necessary clean up to remove the injected code from the
/// instrumented process once the heap walk is complete.
private func eject(module dllPathRemote: UnsafeMutableRawPointer,
from dwProcessId: DWORD,
_ freeLibraryAddr: LPTHREAD_START_ROUTINE) -> Bool {
// Get the dll module handle in the remote process to use it to
// unload it below.

// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
// So, search for it using the snapshot instead.
guard let hModule = find(module: "SwiftInspectClient.dll", in: dwProcessId) else {
private func eject(module pwszModule: UnsafeMutableRawPointer,
from process: DWORD,
_ pfnFunction: LPTHREAD_START_ROUTINE) -> Bool {
// Deallocate the dll path string in the remote process
if !VirtualFreeEx(self.process, pwszModule, 0, DWORD(MEM_RELEASE)) {
print("VirtualFreeEx failed: \(GetLastError())")
}

// Get the dll module handle in the remote process to use it to unload it
// below. `GetExitCodeThread` returns a `DWORD` (32-bit) but the `HMODULE`
// pointer-sized and may be truncated, so, search for it using the snapshot
// instead.
guard let hModule = find(module: "SwiftInspectClient.dll", in: process) else {
print("Failed to find the client dll")
return false
}

// Unload the dll from the remote process
let hUnloadThread = CreateRemoteThread(
self.process, nil, 0, freeLibraryAddr,
UnsafeMutableRawPointer(hModule), 0, nil)
if hUnloadThread == HANDLE(bitPattern: 0) {
guard let hThread = CreateRemoteThread(self.process, nil, 0, pfnFunction,
hModule, 0, nil) else {
print("CreateRemoteThread for unload failed \(GetLastError())")
return false
}
defer { CloseHandle(hUnloadThread) }
let unload_wait = WaitForSingleObject(hUnloadThread, DWORD(WAIT_TIMEOUT_MS))
if unload_wait != WAIT_OBJECT_0 {
print("WaitForSingleObject on FreeLibrary failed \(unload_wait)")
defer { CloseHandle(hThread) }

guard WaitForSingleObject(hThread, WAIT_TIMEOUT_MS) == WAIT_OBJECT_0 else {
print("WaitForSingleObject on FreeLibrary failed \(GetLastError())")
return false
}
var unloadExitCode: DWORD = 0
GetExitCodeThread(hUnloadThread, &unloadExitCode)
if unloadExitCode == 0 {
print("FreeLibrary failed")

var dwExitCode: DWORD = 1
guard GetExitCodeThread(hThread, &dwExitCode) else {
print("GetExitCodeThread for unload failed \(GetLastError())")
return false
}

// Deallocate the dll path string in the remote process
if !VirtualFreeEx(self.process, dllPathRemote, 0, DWORD(MEM_RELEASE)) {
print("VirtualFreeEx failed GLE=\(GetLastError())")
guard dwExitCode == 0 else {
print("FreeLibrary failed \(dwExitCode)")
return false
}

Expand Down Expand Up @@ -516,19 +519,12 @@ internal final class WindowsRemoteProcess: RemoteProcess {
return hModule
}

private func findRemoteAddresses(dwProcessId: DWORD, moduleName: String, symbols: [String])
-> [LPTHREAD_START_ROUTINE]?
{
guard let hDllModule = find(module: moduleName, in: dwProcessId) else {
print("Failed to find remote module \(moduleName)")
private func find(module: String, symbols: [String], in process: DWORD) -> [FARPROC]? {
guard let hModule = find(module: module, in: process) else {
print("Failed to find remote module \(module)")
return nil
}
var addresses: [LPTHREAD_START_ROUTINE] = []
for sym in symbols {
addresses.append(
unsafeBitCast(GetProcAddress(hDllModule, sym), to: LPTHREAD_START_ROUTINE.self))
}
return addresses
return symbols.map { GetProcAddress(hModule, $0) }
}

private func createEventPair(_ dwProcessId: DWORD) -> (HANDLE, HANDLE)? {
Expand Down