Skip to content

Commit 91691a8

Browse files
committed
f - Reconnect and retry HTTP requests
1 parent f9cb75a commit 91691a8

File tree

3 files changed

+88
-48
lines changed

3 files changed

+88
-48
lines changed

lightning-block-sync/src/http.rs

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,8 @@ impl HttpClient {
140140
Host: {}\r\n\
141141
Connection: keep-alive\r\n\
142142
\r\n", uri, host);
143-
self.write_request(request).await?;
144-
let bytes = self.read_response().await?;
145-
F::try_from(bytes)
143+
let response_body = self.send_request_with_retry(&request).await?;
144+
F::try_from(response_body)
146145
}
147146

148147
/// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP
@@ -162,13 +161,37 @@ impl HttpClient {
162161
Content-Length: {}\r\n\
163162
\r\n\
164163
{}", uri, host, auth, content.len(), content);
164+
let response_body = self.send_request_with_retry(&request).await?;
165+
F::try_from(response_body)
166+
}
167+
168+
/// Sends an HTTP request message and reads the response, returning its body. Attempts to
169+
/// reconnect and retry if the connection has been closed.
170+
async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
171+
let endpoint = self.stream.peer_addr().unwrap();
172+
match self.send_request(request).await {
173+
Ok(bytes) => Ok(bytes),
174+
Err(e) => match e.kind() {
175+
std::io::ErrorKind::ConnectionReset |
176+
std::io::ErrorKind::ConnectionAborted |
177+
std::io::ErrorKind::UnexpectedEof => {
178+
// Reconnect if the connection was closed.
179+
*self = Self::connect(endpoint)?;
180+
self.send_request(request).await
181+
},
182+
_ => Err(e),
183+
},
184+
}
185+
}
186+
187+
/// Sends an HTTP request message and reads the response, returning its body.
188+
async fn send_request(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
165189
self.write_request(request).await?;
166-
let bytes = self.read_response().await?;
167-
F::try_from(bytes)
190+
self.read_response().await
168191
}
169192

170193
/// Writes an HTTP request message.
171-
async fn write_request(&mut self, request: String) -> std::io::Result<()> {
194+
async fn write_request(&mut self, request: &str) -> std::io::Result<()> {
172195
#[cfg(feature = "tokio")]
173196
{
174197
self.stream.write_all(request.as_bytes()).await?;
@@ -214,14 +237,14 @@ impl HttpClient {
214237

215238
// Read and parse status line
216239
let status_line = read_line!()
217-
.ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "no status line"))?;
240+
.ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?;
218241
let status = HttpStatus::parse(&status_line)?;
219242

220243
// Read and parse relevant headers
221244
let mut message_length = HttpMessageLength::Empty;
222245
loop {
223246
let line = read_line!()
224-
.ok_or(std::io::Error::new(std::io::ErrorKind::InvalidData, "unexpected eof"))?;
247+
.ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?;
225248
if line.is_empty() { break; }
226249

227250
let header = HttpHeader::parse(&line)?;
@@ -512,21 +535,23 @@ pub(crate) mod client_tests {
512535
let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
513536
let shutdown_signaled = std::sync::Arc::clone(&shutdown);
514537
let handler = std::thread::spawn(move || {
515-
let (mut stream, _) = listener.accept().unwrap();
516-
stream.set_write_timeout(Some(Duration::from_secs(1))).unwrap();
517-
518-
let lines_read = std::io::BufReader::new(&stream)
519-
.lines()
520-
.take_while(|line| !line.as_ref().unwrap().is_empty())
521-
.count();
522-
if lines_read == 0 { return; }
523-
524-
for chunk in response.as_bytes().chunks(16) {
525-
if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) {
526-
break;
527-
} else {
528-
stream.write(chunk).unwrap();
529-
stream.flush().unwrap();
538+
for stream in listener.incoming() {
539+
let mut stream = stream.unwrap();
540+
stream.set_write_timeout(Some(Duration::from_secs(1))).unwrap();
541+
542+
let lines_read = std::io::BufReader::new(&stream)
543+
.lines()
544+
.take_while(|line| !line.as_ref().unwrap().is_empty())
545+
.count();
546+
if lines_read == 0 { continue; }
547+
548+
for chunk in response.as_bytes().chunks(16) {
549+
if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) {
550+
return;
551+
} else {
552+
stream.write(chunk).unwrap();
553+
stream.flush().unwrap();
554+
}
530555
}
531556
}
532557
});
@@ -587,7 +612,7 @@ pub(crate) mod client_tests {
587612
drop(server);
588613
match client.get::<BinaryResponse>("/foo", "foo.com").await {
589614
Err(e) => {
590-
assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
615+
assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
591616
assert_eq!(e.get_ref().unwrap().to_string(), "no status line");
592617
},
593618
Ok(_) => panic!("Expected error"),
@@ -602,8 +627,8 @@ pub(crate) mod client_tests {
602627
drop(server);
603628
match client.get::<BinaryResponse>("/foo", "foo.com").await {
604629
Err(e) => {
605-
assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
606-
assert_eq!(e.get_ref().unwrap().to_string(), "unexpected eof");
630+
assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
631+
assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
607632
},
608633
Ok(_) => panic!("Expected error"),
609634
}
@@ -620,8 +645,8 @@ pub(crate) mod client_tests {
620645
let mut client = HttpClient::connect(&server.endpoint()).unwrap();
621646
match client.get::<BinaryResponse>("/foo", "foo.com").await {
622647
Err(e) => {
623-
assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
624-
assert_eq!(e.get_ref().unwrap().to_string(), "unexpected eof");
648+
assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
649+
assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
625650
},
626651
Ok(_) => panic!("Expected error"),
627652
}
@@ -699,6 +724,18 @@ pub(crate) mod client_tests {
699724
}
700725
}
701726

727+
#[tokio::test]
728+
async fn reconnect_closed_connection() {
729+
let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
730+
731+
let mut client = HttpClient::connect(&server.endpoint()).unwrap();
732+
assert!(client.get::<BinaryResponse>("/foo", "foo.com").await.is_ok());
733+
match client.get::<BinaryResponse>("/foo", "foo.com").await {
734+
Err(e) => panic!("Unexpected error: {:?}", e),
735+
Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
736+
}
737+
}
738+
702739
#[test]
703740
fn from_bytes_into_binary_response() {
704741
let bytes = b"foo";

lightning-block-sync/src/rest.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@ use std::convert::TryInto;
66
/// A simple REST client for requesting resources using HTTP `GET`.
77
pub struct RESTClient {
88
endpoint: HttpEndpoint,
9+
client: HttpClient,
910
}
1011

1112
impl RESTClient {
12-
pub fn new(endpoint: HttpEndpoint) -> Self {
13-
Self { endpoint }
13+
pub fn new(endpoint: HttpEndpoint) -> std::io::Result<Self> {
14+
let client = HttpClient::connect(&endpoint)?;
15+
Ok(Self { endpoint, client })
1416
}
1517

1618
/// Requests a resource encoded in `F` format and interpreted as type `T`.
17-
async fn request_resource<F, T>(&self, resource_path: &str) -> std::io::Result<T>
19+
async fn request_resource<F, T>(&mut self, resource_path: &str) -> std::io::Result<T>
1820
where F: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
1921
let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
2022
let uri = format!("{}/{}", self.endpoint.path().trim_end_matches("/"), resource_path);
21-
22-
let mut client = HttpClient::connect(&self.endpoint)?;
23-
client.get::<F>(&uri, &host).await?.try_into()
23+
self.client.get::<F>(&uri, &host).await?.try_into()
2424
}
2525
}
2626

@@ -48,7 +48,7 @@ mod tests {
4848
#[tokio::test]
4949
async fn request_unknown_resource() {
5050
let server = HttpServer::responding_with_not_found();
51-
let client = RESTClient::new(server.endpoint());
51+
let mut client = RESTClient::new(server.endpoint()).unwrap();
5252

5353
match client.request_resource::<BinaryResponse, u32>("/").await {
5454
Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound),
@@ -59,7 +59,7 @@ mod tests {
5959
#[tokio::test]
6060
async fn request_malformed_resource() {
6161
let server = HttpServer::responding_with_ok(MessageBody::Content("foo"));
62-
let client = RESTClient::new(server.endpoint());
62+
let mut client = RESTClient::new(server.endpoint()).unwrap();
6363

6464
match client.request_resource::<BinaryResponse, u32>("/").await {
6565
Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidData),
@@ -70,7 +70,7 @@ mod tests {
7070
#[tokio::test]
7171
async fn request_valid_resource() {
7272
let server = HttpServer::responding_with_ok(MessageBody::Content(42));
73-
let client = RESTClient::new(server.endpoint());
73+
let mut client = RESTClient::new(server.endpoint()).unwrap();
7474

7575
match client.request_resource::<BinaryResponse, u32>("/").await {
7676
Err(e) => panic!("Unexpected error: {:?}", e),

lightning-block-sync/src/rpc.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,23 @@ use std::sync::atomic::{AtomicUsize, Ordering};
1111
pub struct RPCClient {
1212
basic_auth: String,
1313
endpoint: HttpEndpoint,
14+
client: HttpClient,
1415
id: AtomicUsize,
1516
}
1617

1718
impl RPCClient {
18-
pub fn new(user_auth: &str, endpoint: HttpEndpoint) -> Self {
19-
Self {
19+
pub fn new(user_auth: &str, endpoint: HttpEndpoint) -> std::io::Result<Self> {
20+
let client = HttpClient::connect(&endpoint)?;
21+
Ok(Self {
2022
basic_auth: "Basic ".to_string() + &base64::encode(user_auth),
2123
endpoint,
24+
client,
2225
id: AtomicUsize::new(0),
23-
}
26+
})
2427
}
2528

2629
/// Calls a method with the response encoded in JSON format and interpreted as type `T`.
27-
async fn call_method<T>(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
30+
async fn call_method<T>(&mut self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
2831
where JsonResponse: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
2932
let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
3033
let uri = self.endpoint.path();
@@ -34,8 +37,8 @@ impl RPCClient {
3437
"id": &self.id.fetch_add(1, Ordering::AcqRel).to_string()
3538
});
3639

37-
let mut client = HttpClient::connect(&self.endpoint)?;
38-
let mut response = client.post::<JsonResponse>(&uri, &host, &self.basic_auth, content).await?.0;
40+
let mut response = self.client.post::<JsonResponse>(&uri, &host, &self.basic_auth, content)
41+
.await?.0;
3942
if !response.is_object() {
4043
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object"));
4144
}
@@ -76,7 +79,7 @@ mod tests {
7679
#[tokio::test]
7780
async fn call_method_returning_unknown_response() {
7881
let server = HttpServer::responding_with_not_found();
79-
let client = RPCClient::new("credentials", server.endpoint());
82+
let mut client = RPCClient::new("credentials", server.endpoint()).unwrap();
8083

8184
match client.call_method::<u64>("getblockcount", &[]).await {
8285
Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound),
@@ -88,7 +91,7 @@ mod tests {
8891
async fn call_method_returning_malfomred_response() {
8992
let response = serde_json::json!("foo");
9093
let server = HttpServer::responding_with_ok(MessageBody::Content(response));
91-
let client = RPCClient::new("credentials", server.endpoint());
94+
let mut client = RPCClient::new("credentials", server.endpoint()).unwrap();
9295

9396
match client.call_method::<u64>("getblockcount", &[]).await {
9497
Err(e) => {
@@ -105,7 +108,7 @@ mod tests {
105108
"error": { "code": -8, "message": "invalid parameter" },
106109
});
107110
let server = HttpServer::responding_with_ok(MessageBody::Content(response));
108-
let client = RPCClient::new("credentials", server.endpoint());
111+
let mut client = RPCClient::new("credentials", server.endpoint()).unwrap();
109112

110113
let invalid_block_hash = serde_json::json!("foo");
111114
match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
@@ -121,7 +124,7 @@ mod tests {
121124
async fn call_method_returning_missing_result() {
122125
let response = serde_json::json!({ "result": null });
123126
let server = HttpServer::responding_with_ok(MessageBody::Content(response));
124-
let client = RPCClient::new("credentials", server.endpoint());
127+
let mut client = RPCClient::new("credentials", server.endpoint()).unwrap();
125128

126129
match client.call_method::<u64>("getblockcount", &[]).await {
127130
Err(e) => {
@@ -136,7 +139,7 @@ mod tests {
136139
async fn call_method_returning_valid_result() {
137140
let response = serde_json::json!({ "result": 654470 });
138141
let server = HttpServer::responding_with_ok(MessageBody::Content(response));
139-
let client = RPCClient::new("credentials", server.endpoint());
142+
let mut client = RPCClient::new("credentials", server.endpoint()).unwrap();
140143

141144
match client.call_method::<u64>("getblockcount", &[]).await {
142145
Err(e) => panic!("Unexpected error: {:?}", e),

0 commit comments

Comments
 (0)