Skip to content

Commit b3923cb

Browse files
committed
Fix the race condition in the FileAsyncRequestBody which causes the request to hang
1 parent 4a924a8 commit b3923cb

File tree

5 files changed

+68
-73
lines changed

5 files changed

+68
-73
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java

Lines changed: 36 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import software.amazon.awssdk.core.async.AsyncRequestBody;
3232
import software.amazon.awssdk.core.internal.util.Mimetype;
3333
import software.amazon.awssdk.core.internal.util.NoopSubscription;
34+
import software.amazon.awssdk.utils.Logger;
3435
import software.amazon.awssdk.utils.builder.SdkBuilder;
3536

3637
/**
@@ -41,6 +42,7 @@
4142
*/
4243
@SdkInternalApi
4344
public final class FileAsyncRequestBody implements AsyncRequestBody {
45+
private static final Logger log = Logger.loggerFor(FileAsyncRequestBody.class);
4446

4547
/**
4648
* Default size (in bytes) of ByteBuffer chunks read from the file and delivered to the subscriber.
@@ -169,10 +171,11 @@ private static final class FileSubscription implements Subscription {
169171
private final Subscriber<? super ByteBuffer> subscriber;
170172
private final int chunkSize;
171173

172-
private long position = 0;
173-
private AtomicLong outstandingDemand = new AtomicLong(0);
174-
private boolean writeInProgress = false;
174+
private final AtomicLong position = new AtomicLong(0);
175+
private long outstandingDemand = 0;
176+
private boolean readInProgress = false;
175177
private volatile boolean done = false;
178+
private final Object lock = new Object();
176179

177180
private FileSubscription(AsynchronousFileChannel inputChannel, Subscriber<? super ByteBuffer> subscriber, int chunkSize) {
178181
this.inputChannel = inputChannel;
@@ -189,23 +192,24 @@ public void request(long n) {
189192
if (n < 1) {
190193
IllegalArgumentException ex =
191194
new IllegalArgumentException(subscriber + " violated the Reactive Streams rule 3.9 by requesting a "
192-
+ "non-positive number of elements.");
195+
+ "non-positive number of elements.");
193196
signalOnError(ex);
194197
} else {
195198
try {
196-
// As governed by rule 3.17, when demand overflows `Long.MAX_VALUE` we treat the signalled demand as
197-
// "effectively unbounded"
198-
outstandingDemand.getAndUpdate(initialDemand -> {
199-
if (Long.MAX_VALUE - initialDemand < n) {
200-
return Long.MAX_VALUE;
199+
// We need to synchronize here because of the race condition
200+
// where readData finishes reading at the same time request
201+
// demand comes in
202+
synchronized (lock) {
203+
// As governed by rule 3.17, when demand overflows `Long.MAX_VALUE` we treat the signalled demand as
204+
// "effectively unbounded"
205+
if (Long.MAX_VALUE - outstandingDemand < n) {
206+
outstandingDemand = Long.MAX_VALUE;
201207
} else {
202-
return initialDemand + n;
208+
outstandingDemand += n;
203209
}
204-
});
205210

206-
synchronized (this) {
207-
if (!writeInProgress) {
208-
writeInProgress = true;
211+
if (!readInProgress) {
212+
readInProgress = true;
209213
readData();
210214
}
211215
}
@@ -227,32 +231,33 @@ public void cancel() {
227231

228232
private void readData() {
229233
// It's possible to have another request for data come in after we've closed the file.
230-
if (!inputChannel.isOpen()) {
234+
if (!inputChannel.isOpen() || done) {
231235
return;
232236
}
233237

234238
ByteBuffer buffer = ByteBuffer.allocate(chunkSize);
235-
inputChannel.read(buffer, position, buffer, new CompletionHandler<Integer, ByteBuffer>() {
239+
inputChannel.read(buffer, position.get(), buffer, new CompletionHandler<Integer, ByteBuffer>() {
236240
@Override
237241
public void completed(Integer result, ByteBuffer attachment) {
242+
238243
if (result > 0) {
239244
attachment.flip();
240-
position += attachment.remaining();
245+
position.addAndGet(attachment.remaining());
241246
signalOnNext(attachment);
242-
// If we have more permits, queue up another read.
243-
if (outstandingDemand.decrementAndGet() > 0) {
244-
readData();
245-
return;
247+
248+
synchronized (lock) {
249+
// If we have more permits, queue up another read.
250+
if (--outstandingDemand > 0) {
251+
readData();
252+
} else {
253+
readInProgress = false;
254+
}
246255
}
247256
} else {
248257
// Reached the end of the file, notify the subscriber and cleanup
249258
signalOnComplete();
250259
closeFile();
251260
}
252-
253-
synchronized (FileSubscription.this) {
254-
writeInProgress = false;
255-
}
256261
}
257262

258263
@Override
@@ -267,32 +272,32 @@ private void closeFile() {
267272
try {
268273
inputChannel.close();
269274
} catch (IOException e) {
270-
signalOnError(e);
275+
log.warn(() -> "Failed to close the file", e);
271276
}
272277
}
273278

274-
private void signalOnNext(ByteBuffer bb) {
279+
private void signalOnNext(ByteBuffer attachment) {
275280
synchronized (this) {
276281
if (!done) {
277-
subscriber.onNext(bb);
282+
subscriber.onNext(attachment);
278283
}
279284
}
280285
}
281286

282287
private void signalOnComplete() {
283288
synchronized (this) {
284289
if (!done) {
285-
subscriber.onComplete();
286290
done = true;
291+
subscriber.onComplete();
287292
}
288293
}
289294
}
290295

291296
private void signalOnError(Throwable t) {
292297
synchronized (this) {
293298
if (!done) {
294-
subscriber.onError(t);
295299
done = true;
300+
subscriber.onError(t);
296301
}
297302
}
298303
}
@@ -301,4 +306,4 @@ private void signalOnError(Throwable t) {
301306
private static AsynchronousFileChannel openInputChannel(Path path) throws IOException {
302307
return AsynchronousFileChannel.open(path, StandardOpenOption.READ);
303308
}
304-
}
309+
}

test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3BaseStabilityTest.java

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515

1616
package software.amazon.awssdk.stability.tests.s3;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
1820
import java.io.File;
1921
import java.io.IOException;
2022
import java.nio.file.Path;
@@ -33,10 +35,13 @@
3335
import software.amazon.awssdk.services.s3.model.DeleteBucketRequest;
3436
import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
3537
import software.amazon.awssdk.services.s3.model.NoSuchKeyException;
38+
import software.amazon.awssdk.stability.tests.exceptions.StabilityTestsRetryableException;
39+
import software.amazon.awssdk.stability.tests.utils.RetryableTest;
3640
import software.amazon.awssdk.stability.tests.utils.StabilityTestRunner;
3741
import software.amazon.awssdk.testutils.RandomTempFile;
3842
import software.amazon.awssdk.testutils.service.AwsTestBase;
3943
import software.amazon.awssdk.utils.Logger;
44+
import software.amazon.awssdk.utils.Md5Utils;
4045

4146
public abstract class S3BaseStabilityTest extends AwsTestBase {
4247
private static final Logger log = Logger.loggerFor(S3BaseStabilityTest.class);
@@ -60,6 +65,19 @@ public S3BaseStabilityTest(S3AsyncClient testClient) {
6065
this.testClient = testClient;
6166
}
6267

68+
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
69+
public void largeObject_put_get_usingFile() {
70+
String md5Upload = uploadLargeObjectFromFile();
71+
String md5Download = downloadLargeObjectToFile();
72+
assertThat(md5Upload).isEqualTo(md5Download);
73+
}
74+
75+
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
76+
public void putObject_getObject_highConcurrency() {
77+
putObject();
78+
getObject();
79+
}
80+
6381
protected String computeKeyName(int i) {
6482
return "key_" + i;
6583
}
@@ -79,25 +97,35 @@ protected void doGetBucketAcl_lowTpsLongInterval() {
7997
}
8098

8199

82-
protected void downloadLargeObjectToFile() {
100+
protected String downloadLargeObjectToFile() {
83101
File randomTempFile = RandomTempFile.randomUncreatedFile();
84102
StabilityTestRunner.newRunner()
85103
.testName("S3AsyncStabilityTest.downloadLargeObjectToFile")
86104
.futures(testClient.getObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME),
87105
AsyncResponseTransformer.toFile(randomTempFile)))
88106
.run();
89-
randomTempFile.delete();
107+
108+
109+
try {
110+
return Md5Utils.md5AsBase64(randomTempFile);
111+
} catch (IOException e) {
112+
throw new RuntimeException(e);
113+
} finally {
114+
randomTempFile.delete();
115+
}
90116
}
91117

92-
protected void uploadLargeObjectFromFile() {
118+
protected String uploadLargeObjectFromFile() {
93119
RandomTempFile file = null;
94120
try {
95121
file = new RandomTempFile((long) 2e+9);
122+
String md5 = Md5Utils.md5AsBase64(file);
96123
StabilityTestRunner.newRunner()
97124
.testName("S3AsyncStabilityTest.uploadLargeObjectFromFile")
98125
.futures(testClient.putObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME),
99126
AsyncRequestBody.fromFile(file)))
100127
.run();
128+
return md5;
101129
} catch (IOException e) {
102130
throw new RuntimeException("fail to create test file", e);
103131
} finally {

test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3CrtAsyncClientStabilityTest.java

Lines changed: 1 addition & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -17,12 +17,10 @@
1717

1818
import org.junit.jupiter.api.AfterAll;
1919
import org.junit.jupiter.api.BeforeAll;
20-
import software.amazon.awssdk.transfer.s3.internal.S3CrtAsyncClient;
2120
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
2221
import software.amazon.awssdk.regions.Region;
2322
import software.amazon.awssdk.services.s3.S3AsyncClient;
24-
import software.amazon.awssdk.stability.tests.exceptions.StabilityTestsRetryableException;
25-
import software.amazon.awssdk.stability.tests.utils.RetryableTest;
23+
import software.amazon.awssdk.transfer.s3.internal.S3CrtAsyncClient;
2624

2725
/**
2826
* Stability tests for {@link S3CrtAsyncClient}
@@ -64,16 +62,4 @@ public static void cleanup() {
6462
protected String getTestBucketName() {
6563
return BUCKET_NAME;
6664
}
67-
68-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
69-
public void largeObject_put_get_usingFile() {
70-
uploadLargeObjectFromFile();
71-
downloadLargeObjectToFile();
72-
}
73-
74-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
75-
public void putObject_getObject_highConcurrency() {
76-
putObject();
77-
getObject();
78-
}
7965
}

test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3NettyAsyncStabilityTest.java

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -45,18 +45,6 @@ public static void cleanup() {
4545
@Override
4646
protected String getTestBucketName() { return bucketName; }
4747

48-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
49-
public void putObject_getObject_highConcurrency() {
50-
putObject();
51-
getObject();
52-
}
53-
54-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
55-
public void largeObject_put_get_usingFile() {
56-
uploadLargeObjectFromFile();
57-
downloadLargeObjectToFile();
58-
}
59-
6048
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
6149
public void getBucketAcl_lowTpsLongInterval_Netty() {
6250
doGetBucketAcl_lowTpsLongInterval();

test/stability-tests/src/it/java/software/amazon/awssdk/stability/tests/s3/S3WithCrtAsyncHttpClientStabilityTest.java

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -51,18 +51,6 @@ public static void cleanup() {
5151
@Override
5252
protected String getTestBucketName() { return bucketName; }
5353

54-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
55-
public void putObject_getObject_highConcurrency() {
56-
putObject();
57-
getObject();
58-
}
59-
60-
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
61-
public void largeObject_put_get_usingFile() {
62-
uploadLargeObjectFromFile();
63-
downloadLargeObjectToFile();
64-
}
65-
6654
@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
6755
public void getBucketAcl_lowTpsLongInterval_Crt() {
6856
doGetBucketAcl_lowTpsLongInterval();

0 commit comments

Comments
 (0)