Skip to content

Commit b15470b

Browse files
authored
Extract real_ip middleware (#7432)
1 parent ff9fd96 commit b15470b

File tree

4 files changed

+35
-9
lines changed

4 files changed

+35
-9
lines changed

src/middleware.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod debug;
66
mod ember_html;
77
pub mod log_request;
88
pub mod normalize_path;
9+
pub mod real_ip;
910
mod require_user_agent;
1011
pub mod session;
1112
mod static_or_continue;
@@ -44,6 +45,7 @@ pub fn apply_axum_middleware(state: AppState, router: Router<(), TimeoutBody<Bod
4445
.layer(TimeoutLayer::new(Duration::from_secs(30)))
4546
.layer(sentry_tower::NewSentryLayer::new_from_top())
4647
.layer(sentry_tower::SentryHttpLayer::with_transaction())
48+
.layer(from_fn(self::real_ip::middleware))
4749
.layer(from_fn(log_request::log_requests))
4850
.layer(CatchPanicLayer::new())
4951
.layer(from_fn_with_state(

src/middleware/log_request.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
use crate::controllers::util::RequestPartsExt;
55
use crate::headers::XRequestId;
66
use crate::middleware::normalize_path::OriginalPath;
7-
use crate::real_ip::process_xff_headers;
7+
use crate::middleware::real_ip::RealIp;
88
use axum::headers::UserAgent;
99
use axum::middleware::Next;
1010
use axum::response::IntoResponse;
1111
use axum::{Extension, TypedHeader};
1212
use http::{Method, Request, StatusCode, Uri};
1313
use parking_lot::Mutex;
1414
use std::fmt::{self, Display, Formatter};
15-
use std::net::IpAddr;
1615
use std::ops::Deref;
1716
use std::sync::Arc;
1817
use std::time::{Duration, Instant};
@@ -30,6 +29,7 @@ pub struct RequestMetadata {
3029
method: Method,
3130
uri: Uri,
3231
original_path: Option<Extension<OriginalPath>>,
32+
real_ip: Extension<RealIp>,
3333
user_agent: TypedHeader<UserAgent>,
3434
request_id: Option<TypedHeader<XRequestId>>,
3535
}
@@ -40,7 +40,6 @@ pub struct Metadata<'a> {
4040
cause: Option<&'a CauseField>,
4141
error: Option<&'a ErrorField>,
4242
duration: Duration,
43-
real_ip: Option<IpAddr>,
4443
custom_metadata: RequestLog,
4544
}
4645

@@ -73,8 +72,7 @@ impl Display for Metadata<'_> {
7372
};
7473
}
7574

76-
let real_ip = self.real_ip.map(|ip| ip.to_string()).unwrap_or_default();
77-
line.add_quoted_field("ip", &real_ip)?;
75+
line.add_quoted_field("ip", **self.request.real_ip)?;
7876

7977
let response_time_in_ms = self.duration.as_millis();
8078
if !is_download_redirect || response_time_in_ms > 0 {
@@ -122,8 +120,6 @@ pub async fn log_requests<B>(
122120
let custom_metadata = RequestLog::default();
123121
req.extensions_mut().insert(custom_metadata.clone());
124122

125-
let real_ip = process_xff_headers(req.headers());
126-
127123
let response = next.run(req).await;
128124

129125
let metadata = Metadata {
@@ -132,7 +128,6 @@ pub async fn log_requests<B>(
132128
cause: response.extensions().get(),
133129
error: response.extensions().get(),
134130
duration: start_instant.elapsed(),
135-
real_ip,
136131
custom_metadata,
137132
};
138133

src/middleware/real_ip.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use crate::real_ip::process_xff_headers;
2+
use axum::extract::ConnectInfo;
3+
use axum::middleware::Next;
4+
use axum::response::IntoResponse;
5+
use http::Request;
6+
use std::net::{IpAddr, SocketAddr};
7+
8+
#[derive(Copy, Clone, Debug, Deref)]
9+
pub struct RealIp(IpAddr);
10+
11+
pub async fn middleware<B>(
12+
ConnectInfo(socket_addr): ConnectInfo<SocketAddr>,
13+
mut req: Request<B>,
14+
next: Next<B>,
15+
) -> impl IntoResponse {
16+
let xff_ip = process_xff_headers(req.headers());
17+
let real_ip = xff_ip.unwrap_or_else(|| socket_addr.ip());
18+
19+
req.extensions_mut().insert(RealIp(real_ip));
20+
21+
next.run(req).await
22+
}

src/tests/util.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ use crates_io::models::{ApiToken, CreatedApiToken, User};
2929
use http::{Method, Request};
3030

3131
use axum::body::Bytes;
32+
use axum::extract::connect_info::MockConnectInfo;
3233
use chrono::NaiveDateTime;
3334
use cookie::Cookie;
3435
use crates_io::models::token::{CrateScope, EndpointScope};
3536
use crates_io::util::token::PlainToken;
3637
use http::header;
3738
use secrecy::ExposeSecret;
3839
use std::collections::HashMap;
40+
use std::net::SocketAddr;
3941
use tower_service::Service;
4042

4143
mod chaosproxy;
@@ -92,7 +94,12 @@ pub trait RequestHelper {
9294
/// Run a request that is expected to succeed
9395
#[track_caller]
9496
fn run<T>(&self, request: MockRequest) -> Response<T> {
95-
let mut router = self.app().router().clone();
97+
let router = self.app().router().clone();
98+
99+
// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
100+
// extractor has something to extract.
101+
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
102+
let mut router = router.layer(MockConnectInfo(mocket_addr));
96103

97104
let rt = tokio::runtime::Builder::new_current_thread()
98105
.enable_all()

0 commit comments

Comments
 (0)