Skip to content

Commit 50b48fe

Browse files
authored
feat(middleware-flexible-checksums): support trailing checksums with aws-chunked encoding (#3347)
1 parent d9e4c4b commit 50b48fe

File tree

3 files changed

+97
-27
lines changed

3 files changed

+97
-27
lines changed

packages/middleware-flexible-checksums/src/configuration.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Encoder, HashConstructor, StreamHasher } from "@aws-sdk/types";
1+
import { Encoder, GetAwsChunkedEncodingStream, HashConstructor, StreamHasher } from "@aws-sdk/types";
22

33
export interface PreviouslyResolved {
44
/**
@@ -7,6 +7,16 @@ export interface PreviouslyResolved {
77
*/
88
base64Encoder: Encoder;
99

10+
/**
11+
* A function that can calculate the length of a body.
12+
*/
13+
bodyLengthChecker: (body: any) => number | undefined;
14+
15+
/**
16+
* A function that returns Readable Stream which follows aws-chunked encoding stream.
17+
*/
18+
getAwsChunkedEncodingStream: GetAwsChunkedEncodingStream;
19+
1020
/**
1121
* A constructor for a class implementing the {@link Hash} interface that computes MD5 hashes.
1222
* @internal

packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest
99
import { getChecksumLocationName } from "./getChecksumLocationName";
1010
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
1111
import { hasHeader } from "./hasHeader";
12+
import { isStreaming } from "./isStreaming";
1213
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
1314
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";
1415

1516
jest.mock("@aws-sdk/protocol-http");
17+
jest.mock("./getChecksum");
1618
jest.mock("./getChecksumAlgorithmForRequest");
1719
jest.mock("./getChecksumLocationName");
18-
jest.mock("./selectChecksumAlgorithmFunction");
19-
jest.mock("./getChecksum");
2020
jest.mock("./hasHeader");
21+
jest.mock("./isStreaming");
2122
jest.mock("./validateChecksumFromResponse");
23+
jest.mock("./selectChecksumAlgorithmFunction");
2224

2325
describe(flexibleChecksumsMiddleware.name, () => {
2426
const mockNext = jest.fn();
@@ -31,8 +33,8 @@ describe(flexibleChecksumsMiddleware.name, () => {
3133
const mockConfig = {} as PreviouslyResolved;
3234
const mockMiddlewareConfig = { input: mockInput } as FlexibleChecksumsMiddlewareConfig;
3335

34-
const mockBody = {};
35-
const mockHeaders = {};
36+
const mockBody = { body: "mockBody" };
37+
const mockHeaders = { "content-length": 100 };
3638
const mockRequest = { body: mockBody, headers: mockHeaders };
3739
const mockArgs = { request: mockRequest } as BuildHandlerArguments<any>;
3840
const mockResult = { response: {} };
@@ -41,19 +43,20 @@ describe(flexibleChecksumsMiddleware.name, () => {
4143
mockNext.mockResolvedValueOnce(mockResult);
4244
const { isInstance } = HttpRequest;
4345
(isInstance as unknown as jest.Mock).mockReturnValue(true);
46+
(getChecksum as jest.Mock).mockReturnValue(mockChecksum);
4447
(getChecksumAlgorithmForRequest as jest.Mock).mockReturnValue(ChecksumAlgorithm.MD5);
4548
(getChecksumLocationName as jest.Mock).mockReturnValue(mockChecksumLocationName);
46-
(selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction);
47-
(getChecksum as jest.Mock).mockReturnValue(mockChecksum);
4849
(hasHeader as jest.Mock).mockReturnValue(false);
50+
(isStreaming as jest.Mock).mockReturnValue(false);
51+
(selectChecksumAlgorithmFunction as jest.Mock).mockReturnValue(mockChecksumAlgorithmFunction);
4952
});
5053

5154
afterEach(() => {
5255
expect(mockNext).toHaveBeenCalledTimes(1);
5356
jest.clearAllMocks();
5457
});
5558

56-
describe("skips checksum computation", () => {
59+
describe("skips", () => {
5760
it("if not an instance of HttpRequest", async () => {
5861
const { isInstance } = HttpRequest;
5962
(isInstance as unknown as jest.Mock).mockReturnValue(false);
@@ -65,7 +68,6 @@ describe(flexibleChecksumsMiddleware.name, () => {
6568
describe("request checksum", () => {
6669
afterEach(() => {
6770
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
68-
expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled();
6971
expect(getChecksum).not.toHaveBeenCalled();
7072
});
7173

@@ -75,6 +77,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
7577
await handler(mockArgs);
7678
expect(getChecksumLocationName).not.toHaveBeenCalled();
7779
expect(mockNext).toHaveBeenCalledWith(mockArgs);
80+
expect(selectChecksumAlgorithmFunction).not.toHaveBeenCalled();
7881
});
7982

8083
it("if header is already present", async () => {
@@ -87,6 +90,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
8790
(hasHeader as jest.Mock).mockReturnValue(true);
8891
await handler(mockArgsWithChecksumHeader);
8992
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
93+
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
9094
expect(hasHeader).toHaveBeenCalledTimes(1);
9195
expect(mockNext).toHaveBeenCalledWith(mockArgsWithChecksumHeader);
9296
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeadersWithChecksumHeader);
@@ -112,21 +116,57 @@ describe(flexibleChecksumsMiddleware.name, () => {
112116
});
113117
});
114118

115-
it("adds checksum in the request header", async () => {
116-
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
117-
await handler(mockArgs);
118-
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
119-
expect(hasHeader).toHaveBeenCalledTimes(1);
120-
expect(mockNext).toHaveBeenCalledWith({
121-
...mockArgs,
122-
request: {
123-
...mockRequest,
124-
headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum },
125-
},
119+
describe("adds checksum in the request header", () => {
120+
afterEach(() => {
121+
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
122+
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
123+
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
124+
});
125+
126+
it("for streaming body", async () => {
127+
(isStreaming as jest.Mock).mockReturnValue(true);
128+
const mockUpdatedBody = { body: "mockUpdatedBody" };
129+
const mockGetAwsChunkedEncodingStream = jest.fn().mockReturnValue(mockUpdatedBody);
130+
131+
const handler = flexibleChecksumsMiddleware(
132+
{ ...mockConfig, getAwsChunkedEncodingStream: mockGetAwsChunkedEncodingStream },
133+
mockMiddlewareConfig
134+
)(mockNext, {});
135+
await handler(mockArgs);
136+
137+
expect(mockNext).toHaveBeenCalledWith({
138+
...mockArgs,
139+
request: {
140+
...mockRequest,
141+
headers: {
142+
...mockHeaders,
143+
"content-length": undefined,
144+
"content-encoding": "aws-chunked",
145+
"transfer-encoding": "chunked",
146+
"x-amz-decoded-content-length": mockHeaders["content-length"],
147+
"x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
148+
"x-amz-trailer": mockChecksumLocationName,
149+
},
150+
body: mockUpdatedBody,
151+
},
152+
});
153+
expect(mockGetAwsChunkedEncodingStream).toHaveBeenCalledTimes(1);
154+
});
155+
156+
it("for non-streaming body", async () => {
157+
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
158+
await handler(mockArgs);
159+
expect(hasHeader).toHaveBeenCalledTimes(1);
160+
expect(mockNext).toHaveBeenCalledWith({
161+
...mockArgs,
162+
request: {
163+
...mockRequest,
164+
headers: { ...mockHeaders, [mockChecksumLocationName]: mockChecksum },
165+
},
166+
});
167+
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders);
168+
expect(getChecksum).toHaveBeenCalledTimes(1);
126169
});
127-
expect(hasHeader).toHaveBeenCalledWith(mockChecksumLocationName, mockHeaders);
128-
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
129-
expect(getChecksum).toHaveBeenCalledTimes(1);
130170
});
131171

132172
it("validates checksum from the response header", async () => {

packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest
1313
import { getChecksumLocationName } from "./getChecksumLocationName";
1414
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
1515
import { hasHeader } from "./hasHeader";
16+
import { isStreaming } from "./isStreaming";
1617
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
1718
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";
1819

@@ -26,20 +27,38 @@ export const flexibleChecksumsMiddleware =
2627

2728
const { request } = args;
2829
const { body: requestBody, headers } = request;
29-
const { streamHasher, base64Encoder } = config;
30+
const { base64Encoder, streamHasher } = config;
3031
const { input, requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;
3132

3233
const checksumAlgorithm = getChecksumAlgorithmForRequest(input, {
3334
requestChecksumRequired,
3435
requestAlgorithmMember,
3536
});
37+
let updatedBody = requestBody;
3638
let updatedHeaders = headers;
3739

3840
if (checksumAlgorithm) {
3941
const checksumLocationName = getChecksumLocationName(checksumAlgorithm);
40-
// ToDo: Update trailer instead if it is Unsigned-payload.
41-
if (!hasHeader(checksumLocationName, headers)) {
42-
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config);
42+
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(checksumAlgorithm, config);
43+
if (isStreaming(requestBody)) {
44+
const { getAwsChunkedEncodingStream, bodyLengthChecker } = config;
45+
updatedBody = getAwsChunkedEncodingStream(requestBody, {
46+
base64Encoder,
47+
bodyLengthChecker,
48+
checksumLocationName,
49+
checksumAlgorithmFn,
50+
streamHasher,
51+
});
52+
updatedHeaders = {
53+
...headers,
54+
"content-encoding": "aws-chunked",
55+
"transfer-encoding": "chunked",
56+
"x-amz-decoded-content-length": headers["content-length"],
57+
"x-amz-content-sha256": "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
58+
"x-amz-trailer": checksumLocationName,
59+
};
60+
delete updatedHeaders["content-length"];
61+
} else if (!hasHeader(checksumLocationName, headers)) {
4362
const checksum = await getChecksum(requestBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
4463
updatedHeaders = {
4564
...headers,
@@ -53,6 +72,7 @@ export const flexibleChecksumsMiddleware =
5372
request: {
5473
...request,
5574
headers: updatedHeaders,
75+
body: updatedBody,
5676
},
5777
});
5878

0 commit comments

Comments
 (0)