Skip to content

Commit 4c6ee1c

Browse files
committed
Fixed an issue in ChecksumCalculatingAsyncRequestBody where the position of the ByteBuffer was not honored.
1 parent a803f9d commit 4c6ee1c

File tree

4 files changed

+194
-120
lines changed

4 files changed

+194
-120
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ private static final class SynchronousChunkBuffer {
239239
}
240240

241241
private Iterable<ByteBuffer> buffer(ByteBuffer bytes) {
242-
return chunkBuffer.bufferAndCreateChunks(bytes);
242+
return chunkBuffer.split(bytes);
243243
}
244244
}
245245

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java

Lines changed: 93 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919

2020
import java.nio.ByteBuffer;
2121
import java.util.ArrayList;
22+
import java.util.Arrays;
23+
import java.util.Collections;
2224
import java.util.List;
2325
import java.util.concurrent.atomic.AtomicLong;
2426
import software.amazon.awssdk.annotations.SdkInternalApi;
25-
import software.amazon.awssdk.utils.BinaryUtils;
27+
import software.amazon.awssdk.utils.Logger;
2628
import software.amazon.awssdk.utils.Validate;
2729
import software.amazon.awssdk.utils.builder.SdkBuilder;
2830

@@ -31,70 +33,115 @@
3133
*/
3234
@SdkInternalApi
3335
public final class ChunkBuffer {
34-
private final AtomicLong remainingBytes;
36+
private static final Logger log = Logger.loggerFor(ChunkBuffer.class);
37+
private final AtomicLong transferredBytes;
3538
private final ByteBuffer currentBuffer;
36-
private final int bufferSize;
39+
private final int chunkSize;
40+
private final long totalBytes;
3741

3842
private ChunkBuffer(Long totalBytes, Integer bufferSize) {
3943
Validate.notNull(totalBytes, "The totalBytes must not be null");
4044

4145
int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE;
42-
this.bufferSize = chunkSize;
46+
this.chunkSize = chunkSize;
4347
this.currentBuffer = ByteBuffer.allocate(chunkSize);
44-
this.remainingBytes = new AtomicLong(totalBytes);
48+
this.totalBytes = totalBytes;
49+
this.transferredBytes = new AtomicLong(0);
4550
}
4651

4752
public static Builder builder() {
4853
return new DefaultBuilder();
4954
}
5055

5156

52-
// currentBuffer and bufferedList can get over written if concurrent Threads calls this method at the same time.
53-
public synchronized Iterable<ByteBuffer> bufferAndCreateChunks(ByteBuffer buffer) {
54-
int startPosition = 0;
55-
List<ByteBuffer> bufferedList = new ArrayList<>();
56-
int currentBytesRead = buffer.remaining();
57-
do {
58-
int bufferedBytes = currentBuffer.position();
59-
int availableToRead = bufferSize - bufferedBytes;
60-
int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition);
57+
/**
58+
* Split the input {@link ByteBuffer} into multiple smaller {@link ByteBuffer}s, each of which contains {@link #chunkSize}
59+
* worth of bytes. If the last chunk of the input ByteBuffer contains less than {@link #chunkSize} data, the last chunk will
60+
* be buffered.
61+
*/
62+
public synchronized Iterable<ByteBuffer> split(ByteBuffer inputByteBuffer) {
6163

62-
byte[] bytes = BinaryUtils.copyAllBytesFrom(buffer);
63-
if (bufferedBytes == 0) {
64-
currentBuffer.put(bytes, startPosition, bytesToMove);
65-
} else {
66-
currentBuffer.put(bytes, 0, bytesToMove);
64+
if (!inputByteBuffer.hasRemaining()) {
65+
return Collections.singletonList(inputByteBuffer);
66+
}
67+
68+
List<ByteBuffer> byteBuffers = new ArrayList<>();
69+
70+
// If current buffer is not empty, fill the buffer first.
71+
if (currentBuffer.position() != 0) {
72+
fillBuffer(inputByteBuffer);
73+
74+
if (isCurrentBufferFull()) {
75+
addCurrentBufferToIterable(byteBuffers, chunkSize);
76+
}
77+
}
78+
79+
// If the input buffer is not empty, split the input buffer
80+
if (inputByteBuffer.hasRemaining()) {
81+
splitInputBuffer(inputByteBuffer, byteBuffers);
82+
}
83+
84+
// If this is the last chunk, add data buffered to the iterable
85+
if (isLastChunk()) {
86+
int remainingBytesInBuffer = currentBuffer.position();
87+
addCurrentBufferToIterable(byteBuffers, remainingBytesInBuffer);
88+
}
89+
return byteBuffers;
90+
}
91+
92+
private boolean isCurrentBufferFull() {
93+
return currentBuffer.position() == chunkSize;
94+
}
95+
96+
private void splitInputBuffer(ByteBuffer buffer, List<ByteBuffer> byteBuffers) {
97+
while (buffer.hasRemaining()) {
98+
ByteBuffer chunkByteBuffer = buffer.asReadOnlyBuffer();
99+
if (buffer.remaining() < chunkSize) {
100+
currentBuffer.put(buffer);
101+
break;
67102
}
68103

69-
startPosition = startPosition + bytesToMove;
70-
71-
// Send the data once the buffer is full
72-
if (currentBuffer.position() == bufferSize) {
73-
currentBuffer.position(0);
74-
ByteBuffer bufferToSend = ByteBuffer.allocate(bufferSize);
75-
bufferToSend.put(currentBuffer.array(), 0, bufferSize);
76-
bufferToSend.clear();
77-
currentBuffer.clear();
78-
bufferedList.add(bufferToSend);
79-
remainingBytes.addAndGet(-bufferSize);
104+
int newLimit = chunkByteBuffer.position() + chunkSize;
105+
chunkByteBuffer.limit(newLimit);
106+
buffer.position(newLimit);
107+
byteBuffers.add(chunkByteBuffer);
108+
transferredBytes.addAndGet(chunkSize);
109+
}
110+
}
111+
112+
private boolean isLastChunk() {
113+
long remainingBytes = totalBytes - transferredBytes.get();
114+
return remainingBytes != 0 && remainingBytes == currentBuffer.position();
115+
}
116+
117+
private void addCurrentBufferToIterable(List<ByteBuffer> byteBuffers, int capacity) {
118+
ByteBuffer bufferedChunk = ByteBuffer.allocate(capacity);
119+
currentBuffer.flip();
120+
bufferedChunk.put(currentBuffer);
121+
bufferedChunk.flip();
122+
byteBuffers.add(bufferedChunk);
123+
transferredBytes.addAndGet(bufferedChunk.remaining());
124+
currentBuffer.clear();
125+
}
126+
127+
private void fillBuffer(ByteBuffer inputByteBuffer) {
128+
while (currentBuffer.position() < chunkSize) {
129+
if (!inputByteBuffer.hasRemaining()) {
130+
break;
131+
}
132+
133+
int remainingCapacity = chunkSize - currentBuffer.position();
134+
135+
if (inputByteBuffer.remaining() < remainingCapacity) {
136+
currentBuffer.put(inputByteBuffer);
137+
} else {
138+
ByteBuffer remainingChunk = inputByteBuffer.asReadOnlyBuffer();
139+
int newLimit = inputByteBuffer.position() + remainingCapacity;
140+
remainingChunk.limit(newLimit);
141+
inputByteBuffer.position(newLimit);
142+
currentBuffer.put(remainingChunk);
80143
}
81-
} while (startPosition < currentBytesRead);
82-
83-
int remainingBytesInBuffer = currentBuffer.position();
84-
85-
// Send the remaining buffer when
86-
// 1. remainingBytes in buffer are same as the last few bytes to be read.
87-
// 2. If it is a zero byte and the last byte to be read.
88-
if (remainingBytes.get() == remainingBytesInBuffer &&
89-
(buffer.remaining() == 0 || remainingBytesInBuffer > 0)) {
90-
currentBuffer.clear();
91-
ByteBuffer trimmedBuffer = ByteBuffer.allocate(remainingBytesInBuffer);
92-
trimmedBuffer.put(currentBuffer.array(), 0, remainingBytesInBuffer);
93-
trimmedBuffer.clear();
94-
bufferedList.add(trimmedBuffer);
95-
remainingBytes.addAndGet(-remainingBytesInBuffer);
96144
}
97-
return bufferedList;
98145
}
99146

100147
public interface Builder extends SdkBuilder<Builder, ChunkBuffer> {

core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import static org.assertj.core.api.Assertions.assertThat;
1919
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2020

21+
import java.io.ByteArrayInputStream;
22+
import java.io.IOException;
2123
import java.nio.ByteBuffer;
2224
import java.nio.charset.StandardCharsets;
2325
import java.util.ArrayList;
@@ -29,8 +31,12 @@
2931
import java.util.concurrent.atomic.AtomicInteger;
3032
import java.util.stream.Collectors;
3133
import java.util.stream.IntStream;
34+
import org.apache.commons.lang3.RandomStringUtils;
3235
import org.junit.jupiter.api.Test;
36+
import org.junit.jupiter.params.ParameterizedTest;
37+
import org.junit.jupiter.params.provider.ValueSource;
3338
import software.amazon.awssdk.core.internal.async.ChunkBuffer;
39+
import software.amazon.awssdk.utils.BinaryUtils;
3440
import software.amazon.awssdk.utils.StringUtils;
3541

3642
class ChunkBufferTest {
@@ -40,42 +46,38 @@ void builderWithNoTotalSize() {
4046
assertThatThrownBy(() -> ChunkBuffer.builder().build()).isInstanceOf(NullPointerException.class);
4147
}
4248

43-
@Test
44-
void numberOfChunkMultipleOfTotalBytes() {
45-
String inputString = StringUtils.repeat("*", 25);
46-
47-
ChunkBuffer chunkBuffer =
48-
ChunkBuffer.builder().bufferSize(5).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
49-
Iterable<ByteBuffer> byteBuffers =
50-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
51-
52-
AtomicInteger iteratedCounts = new AtomicInteger();
53-
byteBuffers.forEach(r -> {
54-
iteratedCounts.getAndIncrement();
55-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 5).getBytes(StandardCharsets.UTF_8));
56-
});
57-
assertThat(iteratedCounts.get()).isEqualTo(5);
58-
}
59-
60-
@Test
61-
void numberOfChunk_Not_MultipleOfTotalBytes() {
62-
int totalBytes = 23;
49+
@ParameterizedTest
50+
@ValueSource(ints = {1, 6, 10, 23, 25})
51+
void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) {
6352
int bufferSize = 5;
6453

65-
String inputString = StringUtils.repeat("*", totalBytes);
54+
String inputString = RandomStringUtils.randomAscii(totalBytes);
6655
ChunkBuffer chunkBuffer =
6756
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
6857
Iterable<ByteBuffer> byteBuffers =
69-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
58+
chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
59+
60+
AtomicInteger index = new AtomicInteger(0);
61+
int count = (int) Math.ceil(totalBytes / (double) bufferSize);
62+
int remainder = totalBytes % bufferSize;
7063

71-
AtomicInteger iteratedCounts = new AtomicInteger();
7264
byteBuffers.forEach(r -> {
73-
iteratedCounts.getAndIncrement();
74-
if (iteratedCounts.get() * bufferSize < totalBytes) {
75-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", bufferSize).getBytes(StandardCharsets.UTF_8));
76-
} else {
77-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 3).getBytes(StandardCharsets.UTF_8));
65+
int i = index.get();
7866

67+
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))) {
68+
byte[] expected;
69+
if (i == count - 1 && remainder != 0) {
70+
expected = new byte[remainder];
71+
} else {
72+
expected = new byte[bufferSize];
73+
}
74+
inputStream.skip(i * bufferSize);
75+
inputStream.read(expected);
76+
byte[] actualBytes = BinaryUtils.copyBytesFrom(r);
77+
assertThat(actualBytes).isEqualTo(expected);
78+
index.incrementAndGet();
79+
} catch (IOException e) {
80+
throw new RuntimeException(e);
7981
}
8082
});
8183
}
@@ -86,7 +88,7 @@ void zeroTotalBytesAsInput_returnsZeroByte() {
8688
ChunkBuffer chunkBuffer =
8789
ChunkBuffer.builder().bufferSize(5).totalBytes(zeroByte.length).build();
8890
Iterable<ByteBuffer> byteBuffers =
89-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(zeroByte));
91+
chunkBuffer.split(ByteBuffer.wrap(zeroByte));
9092

9193
AtomicInteger iteratedCounts = new AtomicInteger();
9294
byteBuffers.forEach(r -> {
@@ -104,16 +106,16 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() {
104106
ChunkBuffer chunkBuffer =
105107
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining()).build();
106108
Iterable<ByteBuffer> byteBuffers =
107-
chunkBuffer.bufferAndCreateChunks(wrap);
109+
chunkBuffer.split(wrap);
108110

109111
AtomicInteger iteratedCounts = new AtomicInteger();
110112
byteBuffers.forEach(r -> {
111113
iteratedCounts.getAndIncrement();
112114
if (iteratedCounts.get() * bufferSize < totalBytes) {
113115
// array of empty bytes
114-
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(bufferSize).array());
116+
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(bufferSize).array());
115117
} else {
116-
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
118+
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
117119
}
118120
});
119121
assertThat(iteratedCounts.get()).isEqualTo(4);
@@ -167,7 +169,7 @@ void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException,
167169

168170
futures = IntStream.range(0, threads).<Future<Iterable>>mapToObj(t -> service.submit(() -> {
169171
String inputString = StringUtils.repeat(Integer.toString(counter.incrementAndGet()), totalBytes);
170-
return chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
172+
return chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
171173
})).collect(Collectors.toCollection(() -> new ArrayList<>(threads)));
172174

173175
AtomicInteger filledBuffers = new AtomicInteger(0);

0 commit comments

Comments
 (0)