Skip to content

Commit 7936d57

Browse files
authored
Add content header if the Request doesnot has one for CRT sync Http clients (#4920)
* Add content header if the Request doesnot has one for CRT sync Http clients * Content length updated in the Marshallers * Content length updation done in the Marshallers after internal comment * Handled Matts comments on PR * Updated MarshallersAddContentLengthTest to Junit5 * Updated to initiate theCheckbuilds * Handle sonar raised issues
1 parent c00e4c8 commit 7936d57

File tree

6 files changed

+164
-2
lines changed

6 files changed

+164
-2
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "Add content-length header in Json and Xml Protocol Marshaller for String and Binary explicit Payloads."
6+
}

core/protocols/aws-json-protocol/src/main/java/software/amazon/awssdk/protocols/json/internal/marshall/JsonProtocolMarshaller.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,16 @@ void doMarshall(SdkPojo pojo) {
181181
Object val = field.getValueOrDefault(pojo);
182182
if (isExplicitBinaryPayload(field)) {
183183
if (val != null) {
184-
request.contentStreamProvider(((SdkBytes) val)::asInputStream);
184+
SdkBytes sdkBytes = (SdkBytes) val;
185+
request.contentStreamProvider(sdkBytes::asInputStream);
186+
updateContentLengthHeader(sdkBytes.asByteArrayUnsafe().length);
185187
}
186188
} else if (isExplicitStringPayload(field)) {
187189
if (val != null) {
188190
byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8);
189191
request.contentStreamProvider(() -> new ByteArrayInputStream(content));
192+
updateContentLengthHeader(content.length);
193+
190194
}
191195
} else if (isExplicitPayloadMember(field)) {
192196
marshallExplicitJsonPayload(field, val);
@@ -196,6 +200,10 @@ void doMarshall(SdkPojo pojo) {
196200
}
197201
}
198202

203+
private void updateContentLengthHeader(int contentLength) {
204+
request.putHeader(CONTENT_LENGTH, Integer.toString(contentLength));
205+
}
206+
199207
private boolean isExplicitBinaryPayload(SdkField<?> field) {
200208
return isExplicitPayloadMember(field) && MarshallingType.SDK_BYTES.equals(field.marshallingType());
201209
}

core/protocols/aws-xml-protocol/src/main/java/software/amazon/awssdk/protocols/xml/internal/marshall/XmlProtocolMarshaller.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ void doMarshall(SdkPojo pojo) {
9191
Object val = field.getValueOrDefault(pojo);
9292

9393
if (isBinary(field, val)) {
94-
request.contentStreamProvider(((SdkBytes) val)::asInputStream);
94+
SdkBytes sdkBytes = (SdkBytes) val;
95+
request.contentStreamProvider(sdkBytes::asInputStream);
9596
setContentTypeHeaderIfNeeded("binary/octet-stream");
97+
request.putHeader(CONTENT_LENGTH, Integer.toString(sdkBytes.asByteArrayUnsafe().length));
9698

9799
} else if (isExplicitPayloadMember(field) && val instanceof String) {
98100
byte[] content = ((String) val).getBytes(StandardCharsets.UTF_8);

test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-core-input.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,11 @@
442442
},
443443
"then": {
444444
"serializedAs": {
445+
"headers": {
446+
"contains": {
447+
"content-length": "8"
448+
}
449+
},
445450
"body": {
446451
"equals": "contents"
447452
}

test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/rest-json-input.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
},
5353
"then": {
5454
"serializedAs": {
55+
"headers": {
56+
"contains": {
57+
"Content-length": "22"
58+
}
59+
},
5560
"body": {
5661
"jsonEquals": "{\"StringMember\": \"foo\"}"
5762
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.protocol.tests.contentlength;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
20+
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
21+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
22+
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
23+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
24+
import static com.github.tomakehurst.wiremock.client.WireMock.verify;
25+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
26+
import static software.amazon.awssdk.http.Header.CONTENT_LENGTH;
27+
28+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
29+
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
30+
import java.net.URI;
31+
import java.nio.charset.StandardCharsets;
32+
import org.junit.jupiter.api.Test;
33+
import software.amazon.awssdk.core.SdkBytes;
34+
import software.amazon.awssdk.core.interceptor.Context;
35+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
36+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
37+
import software.amazon.awssdk.http.SdkHttpRequest;
38+
import software.amazon.awssdk.http.crt.AwsCrtHttpClient;
39+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
40+
import software.amazon.awssdk.services.protocolrestjson.model.OperationWithExplicitPayloadStructureResponse;
41+
import software.amazon.awssdk.services.protocolrestjson.model.SimpleStruct;
42+
import software.amazon.awssdk.services.protocolrestxml.ProtocolRestXmlClient;
43+
import software.amazon.awssdk.services.protocolrestxml.model.OperationWithExplicitPayloadStringResponse;
44+
45+
@WireMockTest
46+
public class MarshallersAddContentLengthTest {
47+
public static final String STRING_PAYLOAD = "TEST_STRING_PAYLOAD";
48+
49+
@Test
50+
void jsonMarshallers_AddContentLength_for_explicitBinaryPayload(WireMockRuntimeInfo wireMock) {
51+
stubSuccessfulResponse();
52+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
53+
ProtocolRestJsonClient client = ProtocolRestJsonClient.builder()
54+
.httpClient(AwsCrtHttpClient.builder().build())
55+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
56+
.endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort()))
57+
.build();
58+
client.operationWithExplicitPayloadBlob(p -> p.payloadMember(SdkBytes.fromString(STRING_PAYLOAD,
59+
StandardCharsets.UTF_8)));
60+
verify(postRequestedFor(anyUrl()).withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length()))));
61+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH))
62+
.contains(String.valueOf(STRING_PAYLOAD.length()));
63+
}
64+
65+
@Test
66+
void jsonMarshallers_AddContentLength_for_explicitStringPayload(WireMockRuntimeInfo wireMock) {
67+
stubSuccessfulResponse();
68+
String expectedPayload = String.format("{\"StringMember\":\"%s\"}", STRING_PAYLOAD);
69+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
70+
ProtocolRestJsonClient client = ProtocolRestJsonClient.builder()
71+
.httpClient(AwsCrtHttpClient.builder().build())
72+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
73+
.endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort()))
74+
.build();
75+
OperationWithExplicitPayloadStructureResponse response =
76+
client.operationWithExplicitPayloadStructure(p -> p.payloadMember(SimpleStruct.builder().stringMember(STRING_PAYLOAD).build()));
77+
verify(postRequestedFor(anyUrl())
78+
.withRequestBody(equalTo(expectedPayload))
79+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length()))));
80+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH))
81+
.contains(String.valueOf(expectedPayload.length()));
82+
}
83+
84+
@Test
85+
void xmlMarshallers_AddContentLength_for_explicitBinaryPayload(WireMockRuntimeInfo wireMock) {
86+
stubSuccessfulResponse();
87+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
88+
ProtocolRestXmlClient client = ProtocolRestXmlClient.builder()
89+
.httpClient(AwsCrtHttpClient.builder().build())
90+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
91+
.endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort()))
92+
.build();
93+
client.operationWithExplicitPayloadBlob(r -> r.payloadMember(SdkBytes.fromString(STRING_PAYLOAD,
94+
StandardCharsets.UTF_8)));
95+
verify(postRequestedFor(anyUrl()).withRequestBody(equalTo(STRING_PAYLOAD))
96+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(STRING_PAYLOAD.length()))));
97+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH))
98+
.contains(String.valueOf(STRING_PAYLOAD.length()));
99+
}
100+
101+
@Test
102+
void xmlMarshallers_AddContentLength_for_explicitStringPayload(WireMockRuntimeInfo wireMock) {
103+
stubSuccessfulResponse();
104+
String expectedPayload = STRING_PAYLOAD;
105+
CaptureRequestInterceptor captureRequestInterceptor = new CaptureRequestInterceptor();
106+
ProtocolRestXmlClient client = ProtocolRestXmlClient.builder()
107+
.httpClient(AwsCrtHttpClient.builder().build())
108+
.overrideConfiguration(o -> o.addExecutionInterceptor(captureRequestInterceptor))
109+
.endpointOverride(URI.create("http://localhost:" + wireMock.getHttpPort()))
110+
.build();
111+
OperationWithExplicitPayloadStringResponse stringResponse =
112+
client.operationWithExplicitPayloadString(p -> p.payloadMember(STRING_PAYLOAD));
113+
verify(postRequestedFor(anyUrl())
114+
.withRequestBody(equalTo(expectedPayload))
115+
.withHeader(CONTENT_LENGTH, equalTo(String.valueOf(expectedPayload.length()))));
116+
assertThat(captureRequestInterceptor.requestAfterMarshalling().firstMatchingHeader(CONTENT_LENGTH))
117+
.contains(String.valueOf(expectedPayload.length()));
118+
}
119+
120+
private void stubSuccessfulResponse() {
121+
stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200)));
122+
}
123+
124+
private static class CaptureRequestInterceptor implements ExecutionInterceptor {
125+
private SdkHttpRequest requestAfterMarshilling;
126+
127+
public SdkHttpRequest requestAfterMarshalling() {
128+
return requestAfterMarshilling;
129+
}
130+
131+
@Override
132+
public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) {
133+
this.requestAfterMarshilling = context.httpRequest();
134+
}
135+
}
136+
}

0 commit comments

Comments
 (0)