@@ -70,7 +70,7 @@ use lightning::ln::peer_handler;
70
70
use lightning:: ln:: peer_handler:: SocketDescriptor as LnSocketTrait ;
71
71
use lightning:: ln:: msgs:: ChannelMessageHandler ;
72
72
73
- use std:: task;
73
+ use std:: { task, thread } ;
74
74
use std:: net:: SocketAddr ;
75
75
use std:: sync:: { Arc , Mutex } ;
76
76
use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
@@ -111,6 +111,11 @@ struct Connection {
111
111
// socket. To wake it up (without otherwise changing its state, we can push a value into this
112
112
// Sender.
113
113
read_waker : mpsc:: Sender < ( ) > ,
114
+ // When we are told by rust-lightning to disconnect, we can't return to rust-lightning until we
115
+ // are sure we won't call any more read/write PeerManager functions with the same connection.
116
+ // This is set to true if we're in such a condition (with disconnect checked before with the
117
+ // top-level mutex held) and false when we can return.
118
+ block_disconnect_socket : bool ,
114
119
read_paused : bool ,
115
120
disconnect_state : DisconnectionState ,
116
121
id : u64 ,
@@ -128,17 +133,26 @@ impl Connection {
128
133
} }
129
134
}
130
135
136
+ macro_rules! prepare_read_write_call {
137
+ ( ) => { {
138
+ let mut us_lock = us. lock( ) . unwrap( ) ;
139
+ if us_lock. disconnect_state == DisconnectionState :: RLTriggeredDisconnect {
140
+ shutdown_socket!( "disconnect_socket() call from RL" ) ;
141
+ }
142
+ us_lock. block_disconnect_socket = true ;
143
+ } }
144
+ }
145
+
131
146
let read_paused = us. lock ( ) . unwrap ( ) . read_paused ;
132
147
tokio:: select! {
133
148
v = write_avail_receiver. recv( ) => {
134
149
assert!( v. is_some( ) ) ; // We can't have dropped the sending end, its in the us Arc!
135
- if us. lock( ) . unwrap( ) . disconnect_state == DisconnectionState :: RLTriggeredDisconnect {
136
- shutdown_socket!( "disconnect_socket() call from RL" ) ;
137
- }
150
+ prepare_read_write_call!( ) ;
138
151
if let Err ( e) = peer_manager. write_buffer_space_avail( & mut SocketDescriptor :: new( us. clone( ) ) ) {
139
152
us. lock( ) . unwrap( ) . disconnect_state = DisconnectionState :: RLTriggeredDisconnect ;
140
153
shutdown_socket!( e) ;
141
154
}
155
+ us. lock( ) . unwrap( ) . block_disconnect_socket = false ;
142
156
} ,
143
157
_ = read_wake_receiver. recv( ) => { } ,
144
158
read = reader. read( & mut buf) , if !read_paused => match read {
@@ -147,9 +161,7 @@ impl Connection {
147
161
break ;
148
162
} ,
149
163
Ok ( len) => {
150
- if us. lock( ) . unwrap( ) . disconnect_state == DisconnectionState :: RLTriggeredDisconnect {
151
- shutdown_socket!( "disconnect_socket() call from RL" ) ;
152
- }
164
+ prepare_read_write_call!( ) ;
153
165
match peer_manager. read_event( & mut SocketDescriptor :: new( Arc :: clone( & us) ) , & buf[ 0 ..len] ) {
154
166
Ok ( pause_read) => {
155
167
if pause_read {
@@ -171,6 +183,7 @@ impl Connection {
171
183
shutdown_socket!( e)
172
184
} ,
173
185
}
186
+ us. lock( ) . unwrap( ) . block_disconnect_socket = false ;
174
187
} ,
175
188
Err ( e) => {
176
189
println!( "Connection closed: {}" , e) ;
@@ -179,6 +192,7 @@ impl Connection {
179
192
} ,
180
193
}
181
194
}
195
+ us. lock ( ) . unwrap ( ) . block_disconnect_socket = false ;
182
196
let writer_option = us. lock ( ) . unwrap ( ) . writer . take ( ) ;
183
197
if let Some ( mut writer) = writer_option {
184
198
// If the socket is already closed, shutdown() will fail, so just ignore it.
@@ -212,8 +226,8 @@ impl Connection {
212
226
213
227
( reader, write_receiver, read_receiver,
214
228
Arc :: new ( Mutex :: new ( Self {
215
- writer : Some ( writer) , event_notify, write_avail, read_waker,
216
- read_paused : false , disconnect_state : DisconnectionState :: NeedDisconnectEvent ,
229
+ writer : Some ( writer) , event_notify, write_avail, read_waker, read_paused : false ,
230
+ block_disconnect_socket : false , disconnect_state : DisconnectionState :: NeedDisconnectEvent ,
217
231
id : ID_COUNTER . fetch_add ( 1 , Ordering :: AcqRel )
218
232
} ) ) )
219
233
}
@@ -400,15 +414,18 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
400
414
}
401
415
402
416
fn disconnect_socket ( & mut self ) {
403
- let mut us = self . conn . lock ( ) . unwrap ( ) ;
404
- us. disconnect_state = DisconnectionState :: RLTriggeredDisconnect ;
405
- us. read_paused = true ;
406
- // Wake up the sending thread, assuming it is still alive
407
- let _ = us. write_avail . try_send ( ( ) ) ;
408
- // TODO: There's a race where we don't meet the requirements of disconnect_socket if the
409
- // read task is about to call a PeerManager function (eg read_event or write_event).
410
- // Ideally we need to release the us lock and block until we have confirmation from the
411
- // read task that it has broken out of its main loop.
417
+ {
418
+ let mut us = self . conn . lock ( ) . unwrap ( ) ;
419
+ us. disconnect_state = DisconnectionState :: RLTriggeredDisconnect ;
420
+ us. read_paused = true ;
421
+ // Wake up the sending thread, assuming it is still alive
422
+ let _ = us. write_avail . try_send ( ( ) ) ;
423
+ // Happy-path return:
424
+ if !us. block_disconnect_socket { return ; }
425
+ }
426
+ while self . conn . lock ( ) . unwrap ( ) . block_disconnect_socket {
427
+ thread:: yield_now ( ) ;
428
+ }
412
429
}
413
430
}
414
431
impl Clone for SocketDescriptor {
0 commit comments