@@ -161,16 +161,13 @@ internal final class WindowsRemoteProcess: RemoteProcess {
161
161
self . context = context
162
162
163
163
// Locate swiftCore.dll in the target process and load modules.
164
- iterateRemoteModules (
165
- dwProcessId: processId,
166
- closure: { ( entry, module) in
167
- // FIXME(compnerd) support static linking at some point
168
- if module == " swiftCore.dll " {
169
- self . hSwiftCore = entry. hModule
170
- }
171
- _ = swift_reflection_addImage (
172
- context, unsafeBitCast ( entry. modBaseAddr, to: swift_addr_t. self) )
173
- } )
164
+ modules ( of: processId) { ( entry, module) in
165
+ // FIXME(compnerd) support static linking at some point
166
+ if module == " swiftCore.dll " { self . hSwiftCore = entry. hModule }
167
+ _ = swift_reflection_addImage ( context,
168
+ unsafeBitCast ( entry. modBaseAddr,
169
+ to: swift_addr_t. self) )
170
+ }
174
171
if self . hSwiftCore == HMODULE ( bitPattern: - 1 ) {
175
172
// FIXME(compnerd) log error
176
173
return nil
@@ -321,7 +318,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
321
318
print ( " Failed to find remote LoadLibraryW/FreeLibrary addresses " )
322
319
return
323
320
}
324
- let ( loadLibraryAddr, freeLibraryAddr ) = ( remoteAddrs [ 0 ] , remoteAddrs [ 1 ] )
321
+ let ( loadLibraryAddr, pfnFreeLibrary ) = ( remoteAddrs [ 0 ] , remoteAddrs [ 1 ] )
325
322
let hThread : HANDLE = CreateRemoteThread (
326
323
self . process, nil , 0 , loadLibraryAddr,
327
324
dllPathRemote, 0 , nil )
@@ -331,6 +328,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
331
328
}
332
329
defer { CloseHandle ( hThread) }
333
330
331
+ defer {
332
+ // Always perform the code ejection process even if the heap walk fails.
333
+ // The module cannot re-execute the heap walk and will leave a retain
334
+ // count behind, preventing the module from being unlinked on the file
335
+ // system as well as leave code in the inspected process. This will
336
+ // eventually be an issue for treating the injected code as a resource
337
+ // which is extracted temporarily.
338
+ if !eject( module: dllPathRemote, from: dwProcessId, pfnFreeLibrary) {
339
+ print ( " Failed to unload the remote dll " )
340
+ }
341
+ }
342
+
334
343
// The main heap iteration loop.
335
344
outer: while true {
336
345
let wait = WaitForSingleObject ( hReadEvent, waitTimeoutMs)
@@ -381,14 +390,6 @@ internal final class WindowsRemoteProcess: RemoteProcess {
381
390
print ( " LoadLibraryW failed \( threadExitCode) " )
382
391
return
383
392
}
384
-
385
- // Unload the dll and deallocate the dll path from the remote process
386
- if !unloadDllAndPathRemote(
387
- dwProcessId: dwProcessId, dllPathRemote: dllPathRemote, freeLibraryAddr: freeLibraryAddr)
388
- {
389
- print ( " Failed to unload the remote dll " )
390
- return
391
- }
392
393
}
393
394
394
395
private func allocateDllPathRemote( ) -> UnsafeMutableRawPointer ? {
@@ -430,26 +431,27 @@ internal final class WindowsRemoteProcess: RemoteProcess {
430
431
}
431
432
}
432
433
433
- private func unloadDllAndPathRemote(
434
- dwProcessId: DWORD , dllPathRemote: UnsafeMutableRawPointer ,
435
- freeLibraryAddr: LPTHREAD_START_ROUTINE
436
- ) -> Bool {
434
+ /// Eject the injected code from the instrumented process.
435
+ ///
436
+ /// Performs the necessary clean up to remove the injected code from the
437
+ /// instrumented process once the heap walk is complete.
438
+ private func eject( module dllPathRemote: UnsafeMutableRawPointer ,
439
+ from dwProcessId: DWORD ,
440
+ _ freeLibraryAddr: LPTHREAD_START_ROUTINE ) -> Bool {
437
441
// Get the dll module handle in the remote process to use it to
438
442
// unload it below.
443
+
439
444
// GetExitCodeThread returns a DWORD (32-bit) but the HMODULE
440
445
// returned from LoadLibraryW is a 64-bit pointer and may be truncated.
441
446
// So, search for it using the snapshot instead.
442
- guard
443
- let hDllModule = findRemoteModule (
444
- dwProcessId: dwProcessId, moduleName: " SwiftInspectClient.dll " )
445
- else {
447
+ guard let hModule = find ( module: " SwiftInspectClient.dll " , in: dwProcessId) else {
446
448
print ( " Failed to find the client dll " )
447
449
return false
448
450
}
449
451
// Unload the dll from the remote process
450
452
let hUnloadThread = CreateRemoteThread (
451
453
self . process, nil , 0 , freeLibraryAddr,
452
- UnsafeMutableRawPointer ( hDllModule ) , 0 , nil )
454
+ UnsafeMutableRawPointer ( hModule ) , 0 , nil )
453
455
if hUnloadThread == HANDLE ( bitPattern: 0 ) {
454
456
print ( " CreateRemoteThread for unload failed \( GetLastError ( ) ) " )
455
457
return false
@@ -476,7 +478,7 @@ internal final class WindowsRemoteProcess: RemoteProcess {
476
478
return true
477
479
}
478
480
479
- private func iterateRemoteModules ( dwProcessId: DWORD , closure: ( MODULEENTRY32W , String ) -> Void ) {
481
+ private func modules ( of dwProcessId: DWORD , _ closure: ( MODULEENTRY32W , String ) -> Void ) {
480
482
let hModuleSnapshot : HANDLE =
481
483
CreateToolhelp32Snapshot ( DWORD ( TH32CS_SNAPMODULE) , dwProcessId)
482
484
if hModuleSnapshot == INVALID_HANDLE_VALUE {
@@ -503,22 +505,18 @@ internal final class WindowsRemoteProcess: RemoteProcess {
503
505
} while Module32NextW ( hModuleSnapshot, & entry)
504
506
}
505
507
506
- private func findRemoteModule( dwProcessId: DWORD , moduleName: String ) -> HMODULE ? {
507
- var hDllModule : HMODULE ? = nil
508
- iterateRemoteModules (
509
- dwProcessId: dwProcessId,
510
- closure: { ( entry, module) in
511
- if module == moduleName {
512
- hDllModule = entry. hModule
513
- }
514
- } )
515
- return hDllModule
508
+ private func find( module named: String , in dwProcessId: DWORD ) -> HMODULE ? {
509
+ var hModule : HMODULE ?
510
+ modules ( of: dwProcessId) { ( entry, module) in
511
+ if module == named { hModule = entry. hModule }
512
+ }
513
+ return hModule
516
514
}
517
515
518
516
private func findRemoteAddresses( dwProcessId: DWORD , moduleName: String , symbols: [ String ] )
519
517
-> [ LPTHREAD_START_ROUTINE ] ?
520
518
{
521
- guard let hDllModule = findRemoteModule ( dwProcessId : dwProcessId , moduleName : moduleName ) else {
519
+ guard let hDllModule = find ( module : moduleName , in : dwProcessId ) else {
522
520
print ( " Failed to find remote module \( moduleName) " )
523
521
return nil
524
522
}
0 commit comments