Skip to content

Commit 41aa421

Browse files
committed
Polish WebFlux ForwardedHeaderFilter and tests
Preparation for SPR-17072
1 parent 02403f6 commit 41aa421

File tree

2 files changed

+91
-68
lines changed

2 files changed

+91
-68
lines changed

spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
*/
4545
public class ForwardedHeaderFilter implements WebFilter {
4646

47-
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
47+
static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
4848

4949
static {
5050
FORWARDED_HEADER_NAMES.add("Forwarded");
@@ -72,54 +72,58 @@ public void setRemoveOnly(boolean removeOnly) {
7272
@Override
7373
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
7474

75-
if (shouldNotFilter(exchange.getRequest())) {
75+
ServerHttpRequest request = exchange.getRequest();
76+
if (!hasForwardedHeaders(request)) {
7677
return chain.filter(exchange);
7778
}
7879

7980
ServerWebExchange mutatedExchange;
80-
8181
if (this.removeOnly) {
82-
mutatedExchange = exchange.mutate().request(builder ->
83-
builder.headers(headers -> {
84-
FORWARDED_HEADER_NAMES.forEach(headers::remove);
85-
}))
86-
.build();
82+
mutatedExchange = exchange.mutate().request(this::removeForwardedHeaders).build();
8783
}
8884
else {
89-
URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri();
90-
String prefix = getForwardedPrefix(exchange.getRequest().getHeaders());
91-
92-
mutatedExchange = exchange.mutate().request(builder -> {
93-
builder.uri(uri);
94-
if (prefix != null) {
95-
builder.path(prefix + uri.getPath());
96-
builder.contextPath(prefix);
97-
}
98-
}).build();
85+
mutatedExchange = exchange.mutate()
86+
.request(builder -> {
87+
URI uri = UriComponentsBuilder.fromHttpRequest(request).build().toUri();
88+
builder.uri(uri);
89+
String prefix = getForwardedPrefix(request);
90+
if (prefix != null) {
91+
builder.path(prefix + uri.getPath());
92+
builder.contextPath(prefix);
93+
}
94+
})
95+
.build();
9996
}
10097

10198
return chain.filter(mutatedExchange);
10299
}
103100

104-
private boolean shouldNotFilter(ServerHttpRequest request) {
101+
private boolean hasForwardedHeaders(ServerHttpRequest request) {
105102
HttpHeaders headers = request.getHeaders();
106103
for (String headerName : FORWARDED_HEADER_NAMES) {
107104
if (headers.containsKey(headerName)) {
108-
return false;
105+
return true;
109106
}
110107
}
111-
return true;
108+
return false;
112109
}
113110

114111
@Nullable
115-
private static String getForwardedPrefix(HttpHeaders headers) {
112+
private static String getForwardedPrefix(ServerHttpRequest request) {
113+
HttpHeaders headers = request.getHeaders();
116114
String prefix = headers.getFirst("X-Forwarded-Prefix");
117115
if (prefix != null) {
118-
while (prefix.endsWith("/")) {
119-
prefix = prefix.substring(0, prefix.length() - 1);
120-
}
116+
int endIndex = prefix.length();
117+
while (endIndex > 1 && prefix.charAt(endIndex - 1) == '/') {
118+
endIndex--;
119+
};
120+
prefix = endIndex != prefix.length() ? prefix.substring(0, endIndex) : prefix;
121121
}
122122
return prefix;
123123
}
124124

125+
private ServerHttpRequest.Builder removeForwardedHeaders(ServerHttpRequest.Builder builder) {
126+
return builder.headers(map -> FORWARDED_HEADER_NAMES.forEach(map::remove));
127+
}
128+
125129
}

spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,19 @@
2323
import reactor.core.publisher.Mono;
2424

2525
import org.springframework.http.HttpHeaders;
26+
import org.springframework.http.server.reactive.ServerHttpRequest;
2627
import org.springframework.lang.Nullable;
28+
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
2729
import org.springframework.mock.web.test.server.MockServerWebExchange;
2830
import org.springframework.web.server.ServerWebExchange;
2931
import org.springframework.web.server.WebFilterChain;
3032

3133
import static org.junit.Assert.*;
32-
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*;
3334

3435
/**
36+
* Unit tests for {@link ForwardedHeaderFilter}.
3537
* @author Arjen Poutsma
38+
* @author Rossen Stoyanchev
3639
*/
3740
public class ForwardedHeaderFilterTests {
3841

@@ -46,65 +49,65 @@ public class ForwardedHeaderFilterTests {
4649

4750
@Test
4851
public void removeOnly() {
49-
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
50-
.header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43")
51-
.header("X-Forwarded-Host", "example.com")
52-
.header("X-Forwarded-Port", "8080")
53-
.header("X-Forwarded-Proto", "http")
54-
.header("X-Forwarded-Prefix", "prefix")
55-
.header("X-Forwarded-Ssl", "on"));
5652

5753
this.filter.setRemoveOnly(true);
58-
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
59-
60-
HttpHeaders result = this.filterChain.getHeaders();
61-
assertNotNull(result);
62-
assertFalse(result.containsKey("Forwarded"));
63-
assertFalse(result.containsKey("X-Forwarded-Host"));
64-
assertFalse(result.containsKey("X-Forwarded-Port"));
65-
assertFalse(result.containsKey("X-Forwarded-Proto"));
66-
assertFalse(result.containsKey("X-Forwarded-Prefix"));
67-
assertFalse(result.containsKey("X-Forwarded-Ssl"));
54+
55+
HttpHeaders headers = new HttpHeaders();
56+
headers.add("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43");
57+
headers.add("X-Forwarded-Host", "example.com");
58+
headers.add("X-Forwarded-Port", "8080");
59+
headers.add("X-Forwarded-Proto", "http");
60+
headers.add("X-Forwarded-Prefix", "prefix");
61+
headers.add("X-Forwarded-Ssl", "on");
62+
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
63+
64+
this.filterChain.assertForwardedHeadersRemoved();
6865
}
6966

7067
@Test
71-
public void xForwardedRequest() throws Exception {
72-
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
73-
.header("X-Forwarded-Host", "84.198.58.199")
74-
.header("X-Forwarded-Port", "443")
75-
.header("X-Forwarded-Proto", "https"));
76-
77-
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange));
68+
public void xForwardedHeaders() throws Exception {
69+
HttpHeaders headers = new HttpHeaders();
70+
headers.add("X-Forwarded-Host", "84.198.58.199");
71+
headers.add("X-Forwarded-Port", "443");
72+
headers.add("X-Forwarded-Proto", "https");
73+
headers.add("foo", "bar");
74+
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
75+
76+
assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
7877
}
7978

8079
@Test
81-
public void forwardedRequest() throws Exception {
82-
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
83-
.header("Forwarded", "host=84.198.58.199;proto=https"));
80+
public void forwardedHeader() throws Exception {
81+
HttpHeaders headers = new HttpHeaders();
82+
headers.add("Forwarded", "host=84.198.58.199;proto=https");
83+
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
8484

85-
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange));
85+
assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
8686
}
8787

8888
@Test
89-
public void requestUriWithForwardedPrefix() throws Exception {
90-
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
91-
.header("X-Forwarded-Prefix", "/prefix"));
89+
public void xForwardedPrefix() throws Exception {
90+
HttpHeaders headers = new HttpHeaders();
91+
headers.add("X-Forwarded-Prefix", "/prefix");
92+
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
9293

93-
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange));
94+
assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
95+
assertEquals("/prefix/path", this.filterChain.requestPathValue);
9496
}
9597

9698
@Test
97-
public void requestUriWithForwardedPrefixTrailingSlash() throws Exception {
98-
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
99-
.header("X-Forwarded-Prefix", "/prefix/"));
99+
public void xForwardedPrefixTrailingSlash() throws Exception {
100+
HttpHeaders headers = new HttpHeaders();
101+
headers.add("X-Forwarded-Prefix", "/prefix////");
102+
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
100103

101-
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange));
104+
assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
105+
assertEquals("/prefix/path", this.filterChain.requestPathValue);
102106
}
103107

104-
@Nullable
105-
private URI filterAndGetUri(ServerWebExchange exchange) {
106-
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
107-
return this.filterChain.uri;
108+
private MockServerWebExchange getExchange(HttpHeaders headers) {
109+
MockServerHttpRequest request = MockServerHttpRequest.get(BASE_URL).headers(headers).build();
110+
return MockServerWebExchange.from(request);
108111
}
109112

110113

@@ -116,21 +119,37 @@ private static class TestWebFilterChain implements WebFilterChain {
116119
@Nullable
117120
private URI uri;
118121

122+
@Nullable String requestPathValue;
123+
119124

120125
@Nullable
121126
public HttpHeaders getHeaders() {
122127
return this.headers;
123128
}
124129

130+
@Nullable
131+
public String getHeader(String name) {
132+
assertNotNull(this.headers);
133+
return this.headers.getFirst(name);
134+
}
135+
136+
public void assertForwardedHeadersRemoved() {
137+
assertNotNull(this.headers);
138+
ForwardedHeaderFilter.FORWARDED_HEADER_NAMES
139+
.forEach(name -> assertFalse(this.headers.containsKey(name)));
140+
}
141+
125142
@Nullable
126143
public URI getUri() {
127144
return this.uri;
128145
}
129146

130147
@Override
131148
public Mono<Void> filter(ServerWebExchange exchange) {
132-
this.headers = exchange.getRequest().getHeaders();
133-
this.uri = exchange.getRequest().getURI();
149+
ServerHttpRequest request = exchange.getRequest();
150+
this.headers = request.getHeaders();
151+
this.uri = request.getURI();
152+
this.requestPathValue = request.getPath().value();
134153
return Mono.empty();
135154
}
136155
}

0 commit comments

Comments
 (0)