@@ -140,9 +140,8 @@ impl HttpClient {
140
140
Host: {}\r \n \
141
141
Connection: keep-alive\r \n \
142
142
\r \n ", uri, host) ;
143
- self . write_request ( request) . await ?;
144
- let bytes = self . read_response ( ) . await ?;
145
- F :: try_from ( bytes)
143
+ let response_body = self . send_request_with_retry ( & request) . await ?;
144
+ F :: try_from ( response_body)
146
145
}
147
146
148
147
/// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP
@@ -162,13 +161,37 @@ impl HttpClient {
162
161
Content-Length: {}\r \n \
163
162
\r \n \
164
163
{}", uri, host, auth, content. len( ) , content) ;
164
+ let response_body = self . send_request_with_retry ( & request) . await ?;
165
+ F :: try_from ( response_body)
166
+ }
167
+
168
+ /// Sends an HTTP request message and reads the response, returning its body. Attempts to
169
+ /// reconnect and retry if the connection has been closed.
170
+ async fn send_request_with_retry ( & mut self , request : & str ) -> std:: io:: Result < Vec < u8 > > {
171
+ let endpoint = self . stream . peer_addr ( ) . unwrap ( ) ;
172
+ match self . send_request ( request) . await {
173
+ Ok ( bytes) => Ok ( bytes) ,
174
+ Err ( e) => match e. kind ( ) {
175
+ std:: io:: ErrorKind :: ConnectionReset |
176
+ std:: io:: ErrorKind :: ConnectionAborted |
177
+ std:: io:: ErrorKind :: UnexpectedEof => {
178
+ // Reconnect if the connection was closed.
179
+ * self = Self :: connect ( endpoint) ?;
180
+ self . send_request ( request) . await
181
+ } ,
182
+ _ => Err ( e) ,
183
+ } ,
184
+ }
185
+ }
186
+
187
+ /// Sends an HTTP request message and reads the response, returning its body.
188
+ async fn send_request ( & mut self , request : & str ) -> std:: io:: Result < Vec < u8 > > {
165
189
self . write_request ( request) . await ?;
166
- let bytes = self . read_response ( ) . await ?;
167
- F :: try_from ( bytes)
190
+ self . read_response ( ) . await
168
191
}
169
192
170
193
/// Writes an HTTP request message.
171
- async fn write_request ( & mut self , request : String ) -> std:: io:: Result < ( ) > {
194
+ async fn write_request ( & mut self , request : & str ) -> std:: io:: Result < ( ) > {
172
195
#[ cfg( feature = "tokio" ) ]
173
196
{
174
197
self . stream . write_all ( request. as_bytes ( ) ) . await ?;
@@ -214,14 +237,14 @@ impl HttpClient {
214
237
215
238
// Read and parse status line
216
239
let status_line = read_line ! ( )
217
- . ok_or ( std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , "no status line" ) ) ?;
240
+ . ok_or ( std:: io:: Error :: new ( std:: io:: ErrorKind :: UnexpectedEof , "no status line" ) ) ?;
218
241
let status = HttpStatus :: parse ( & status_line) ?;
219
242
220
243
// Read and parse relevant headers
221
244
let mut message_length = HttpMessageLength :: Empty ;
222
245
loop {
223
246
let line = read_line ! ( )
224
- . ok_or ( std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , "unexpected eof " ) ) ?;
247
+ . ok_or ( std:: io:: Error :: new ( std:: io:: ErrorKind :: UnexpectedEof , "no headers " ) ) ?;
225
248
if line. is_empty ( ) { break ; }
226
249
227
250
let header = HttpHeader :: parse ( & line) ?;
@@ -512,21 +535,23 @@ pub(crate) mod client_tests {
512
535
let shutdown = std:: sync:: Arc :: new ( std:: sync:: atomic:: AtomicBool :: new ( false ) ) ;
513
536
let shutdown_signaled = std:: sync:: Arc :: clone ( & shutdown) ;
514
537
let handler = std:: thread:: spawn ( move || {
515
- let ( mut stream, _) = listener. accept ( ) . unwrap ( ) ;
516
- stream. set_write_timeout ( Some ( Duration :: from_secs ( 1 ) ) ) . unwrap ( ) ;
517
-
518
- let lines_read = std:: io:: BufReader :: new ( & stream)
519
- . lines ( )
520
- . take_while ( |line| !line. as_ref ( ) . unwrap ( ) . is_empty ( ) )
521
- . count ( ) ;
522
- if lines_read == 0 { return ; }
523
-
524
- for chunk in response. as_bytes ( ) . chunks ( 16 ) {
525
- if shutdown_signaled. load ( std:: sync:: atomic:: Ordering :: SeqCst ) {
526
- break ;
527
- } else {
528
- stream. write ( chunk) . unwrap ( ) ;
529
- stream. flush ( ) . unwrap ( ) ;
538
+ for stream in listener. incoming ( ) {
539
+ let mut stream = stream. unwrap ( ) ;
540
+ stream. set_write_timeout ( Some ( Duration :: from_secs ( 1 ) ) ) . unwrap ( ) ;
541
+
542
+ let lines_read = std:: io:: BufReader :: new ( & stream)
543
+ . lines ( )
544
+ . take_while ( |line| !line. as_ref ( ) . unwrap ( ) . is_empty ( ) )
545
+ . count ( ) ;
546
+ if lines_read == 0 { continue ; }
547
+
548
+ for chunk in response. as_bytes ( ) . chunks ( 16 ) {
549
+ if shutdown_signaled. load ( std:: sync:: atomic:: Ordering :: SeqCst ) {
550
+ return ;
551
+ } else {
552
+ stream. write ( chunk) . unwrap ( ) ;
553
+ stream. flush ( ) . unwrap ( ) ;
554
+ }
530
555
}
531
556
}
532
557
} ) ;
@@ -587,7 +612,7 @@ pub(crate) mod client_tests {
587
612
drop ( server) ;
588
613
match client. get :: < BinaryResponse > ( "/foo" , "foo.com" ) . await {
589
614
Err ( e) => {
590
- assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: InvalidData ) ;
615
+ assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: UnexpectedEof ) ;
591
616
assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "no status line" ) ;
592
617
} ,
593
618
Ok ( _) => panic ! ( "Expected error" ) ,
@@ -602,8 +627,8 @@ pub(crate) mod client_tests {
602
627
drop ( server) ;
603
628
match client. get :: < BinaryResponse > ( "/foo" , "foo.com" ) . await {
604
629
Err ( e) => {
605
- assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: InvalidData ) ;
606
- assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "unexpected eof " ) ;
630
+ assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: UnexpectedEof ) ;
631
+ assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "no headers " ) ;
607
632
} ,
608
633
Ok ( _) => panic ! ( "Expected error" ) ,
609
634
}
@@ -620,8 +645,8 @@ pub(crate) mod client_tests {
620
645
let mut client = HttpClient :: connect ( & server. endpoint ( ) ) . unwrap ( ) ;
621
646
match client. get :: < BinaryResponse > ( "/foo" , "foo.com" ) . await {
622
647
Err ( e) => {
623
- assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: InvalidData ) ;
624
- assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "unexpected eof " ) ;
648
+ assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: UnexpectedEof ) ;
649
+ assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "no headers " ) ;
625
650
} ,
626
651
Ok ( _) => panic ! ( "Expected error" ) ,
627
652
}
@@ -699,6 +724,18 @@ pub(crate) mod client_tests {
699
724
}
700
725
}
701
726
727
+ #[ tokio:: test]
728
+ async fn reconnect_closed_connection ( ) {
729
+ let server = HttpServer :: responding_with_ok :: < String > ( MessageBody :: Empty ) ;
730
+
731
+ let mut client = HttpClient :: connect ( & server. endpoint ( ) ) . unwrap ( ) ;
732
+ assert ! ( client. get:: <BinaryResponse >( "/foo" , "foo.com" ) . await . is_ok( ) ) ;
733
+ match client. get :: < BinaryResponse > ( "/foo" , "foo.com" ) . await {
734
+ Err ( e) => panic ! ( "Unexpected error: {:?}" , e) ,
735
+ Ok ( bytes) => assert_eq ! ( bytes. 0 , Vec :: <u8 >:: new( ) ) ,
736
+ }
737
+ }
738
+
702
739
#[ test]
703
740
fn from_bytes_into_binary_response ( ) {
704
741
let bytes = b"foo" ;
0 commit comments