Skip to content

Fix a race condition in FileAsyncRequestBody #2536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.core.internal.util.NoopSubscription;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.builder.SdkBuilder;

/**
Expand All @@ -41,6 +42,7 @@
*/
@SdkInternalApi
public final class FileAsyncRequestBody implements AsyncRequestBody {
private static final Logger log = Logger.loggerFor(FileAsyncRequestBody.class);

/**
* Default size (in bytes) of ByteBuffer chunks read from the file and delivered to the subscriber.
Expand Down Expand Up @@ -169,10 +171,11 @@ private static final class FileSubscription implements Subscription {
private final Subscriber<? super ByteBuffer> subscriber;
private final int chunkSize;

private long position = 0;
private AtomicLong outstandingDemand = new AtomicLong(0);
private boolean writeInProgress = false;
private final AtomicLong position = new AtomicLong(0);
private long outstandingDemand = 0;
private boolean readInProgress = false;
private volatile boolean done = false;
private final Object lock = new Object();

private FileSubscription(AsynchronousFileChannel inputChannel, Subscriber<? super ByteBuffer> subscriber, int chunkSize) {
this.inputChannel = inputChannel;
Expand All @@ -189,23 +192,24 @@ public void request(long n) {
if (n < 1) {
IllegalArgumentException ex =
new IllegalArgumentException(subscriber + " violated the Reactive Streams rule 3.9 by requesting a "
+ "non-positive number of elements.");
+ "non-positive number of elements.");
signalOnError(ex);
} else {
try {
// As governed by rule 3.17, when demand overflows `Long.MAX_VALUE` we treat the signalled demand as
// "effectively unbounded"
outstandingDemand.getAndUpdate(initialDemand -> {
if (Long.MAX_VALUE - initialDemand < n) {
return Long.MAX_VALUE;
// We need to synchronize here because of the race condition
// where readData finishes reading at the same time request
// demand comes in
synchronized (lock) {
// As governed by rule 3.17, when demand overflows `Long.MAX_VALUE` we treat the signalled demand as
// "effectively unbounded"
if (Long.MAX_VALUE - outstandingDemand < n) {
outstandingDemand = Long.MAX_VALUE;
} else {
return initialDemand + n;
outstandingDemand += n;
}
});

synchronized (this) {
if (!writeInProgress) {
writeInProgress = true;
if (!readInProgress) {
readInProgress = true;
readData();
}
}
Expand All @@ -227,32 +231,33 @@ public void cancel() {

private void readData() {
// It's possible to have another request for data come in after we've closed the file.
if (!inputChannel.isOpen()) {
if (!inputChannel.isOpen() || done) {
return;
}

ByteBuffer buffer = ByteBuffer.allocate(chunkSize);
inputChannel.read(buffer, position, buffer, new CompletionHandler<Integer, ByteBuffer>() {
inputChannel.read(buffer, position.get(), buffer, new CompletionHandler<Integer, ByteBuffer>() {
@Override
public void completed(Integer result, ByteBuffer attachment) {

if (result > 0) {
attachment.flip();
position += attachment.remaining();
position.addAndGet(attachment.remaining());
signalOnNext(attachment);
// If we have more permits, queue up another read.
if (outstandingDemand.decrementAndGet() > 0) {
readData();
return;

synchronized (lock) {
// If we have more permits, queue up another read.
if (--outstandingDemand > 0) {
readData();
} else {
readInProgress = false;
}
}
} else {
// Reached the end of the file, notify the subscriber and cleanup
signalOnComplete();
closeFile();
}

synchronized (FileSubscription.this) {
writeInProgress = false;
}
}

@Override
Expand All @@ -267,32 +272,32 @@ private void closeFile() {
try {
inputChannel.close();
} catch (IOException e) {
signalOnError(e);
log.warn(() -> "Failed to close the file", e);
}
}

private void signalOnNext(ByteBuffer bb) {
private void signalOnNext(ByteBuffer attachment) {
synchronized (this) {
if (!done) {
subscriber.onNext(bb);
subscriber.onNext(attachment);
}
}
}

private void signalOnComplete() {
synchronized (this) {
if (!done) {
subscriber.onComplete();
done = true;
subscriber.onComplete();
}
}
}

private void signalOnError(Throwable t) {
synchronized (this) {
if (!done) {
subscriber.onError(t);
done = true;
subscriber.onError(t);
}
}
}
Expand All @@ -301,4 +306,4 @@ private void signalOnError(Throwable t) {
private static AsynchronousFileChannel openInputChannel(Path path) throws IOException {
return AsynchronousFileChannel.open(path, StandardOpenOption.READ);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package software.amazon.awssdk.stability.tests.s3;

import static org.assertj.core.api.Assertions.assertThat;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -33,10 +35,13 @@
import software.amazon.awssdk.services.s3.model.DeleteBucketRequest;
import software.amazon.awssdk.services.s3.model.NoSuchBucketException;
import software.amazon.awssdk.services.s3.model.NoSuchKeyException;
import software.amazon.awssdk.stability.tests.exceptions.StabilityTestsRetryableException;
import software.amazon.awssdk.stability.tests.utils.RetryableTest;
import software.amazon.awssdk.stability.tests.utils.StabilityTestRunner;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.testutils.service.AwsTestBase;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Md5Utils;

public abstract class S3BaseStabilityTest extends AwsTestBase {
private static final Logger log = Logger.loggerFor(S3BaseStabilityTest.class);
Expand All @@ -60,6 +65,19 @@ public S3BaseStabilityTest(S3AsyncClient testClient) {
this.testClient = testClient;
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void largeObject_put_get_usingFile() {
String md5Upload = uploadLargeObjectFromFile();
String md5Download = downloadLargeObjectToFile();
assertThat(md5Upload).isEqualTo(md5Download);
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void putObject_getObject_highConcurrency() {
putObject();
getObject();
}

protected String computeKeyName(int i) {
return "key_" + i;
}
Expand All @@ -79,25 +97,35 @@ protected void doGetBucketAcl_lowTpsLongInterval() {
}


protected void downloadLargeObjectToFile() {
protected String downloadLargeObjectToFile() {
File randomTempFile = RandomTempFile.randomUncreatedFile();
StabilityTestRunner.newRunner()
.testName("S3AsyncStabilityTest.downloadLargeObjectToFile")
.futures(testClient.getObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME),
AsyncResponseTransformer.toFile(randomTempFile)))
.run();
randomTempFile.delete();


try {
return Md5Utils.md5AsBase64(randomTempFile);
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
randomTempFile.delete();
}
}

protected void uploadLargeObjectFromFile() {
protected String uploadLargeObjectFromFile() {
RandomTempFile file = null;
try {
file = new RandomTempFile((long) 2e+9);
String md5 = Md5Utils.md5AsBase64(file);
StabilityTestRunner.newRunner()
.testName("S3AsyncStabilityTest.uploadLargeObjectFromFile")
.futures(testClient.putObject(b -> b.bucket(getTestBucketName()).key(LARGE_KEY_NAME),
AsyncRequestBody.fromFile(file)))
.run();
return md5;
} catch (IOException e) {
throw new RuntimeException("fail to create test file", e);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import software.amazon.awssdk.transfer.s3.internal.S3CrtAsyncClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.stability.tests.exceptions.StabilityTestsRetryableException;
import software.amazon.awssdk.stability.tests.utils.RetryableTest;
import software.amazon.awssdk.transfer.s3.internal.S3CrtAsyncClient;

/**
* Stability tests for {@link S3CrtAsyncClient}
Expand Down Expand Up @@ -64,16 +62,4 @@ public static void cleanup() {
protected String getTestBucketName() {
return BUCKET_NAME;
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void largeObject_put_get_usingFile() {
uploadLargeObjectFromFile();
downloadLargeObjectToFile();
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void putObject_getObject_highConcurrency() {
putObject();
getObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,6 @@ public static void cleanup() {
@Override
protected String getTestBucketName() { return bucketName; }

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void putObject_getObject_highConcurrency() {
putObject();
getObject();
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void largeObject_put_get_usingFile() {
uploadLargeObjectFromFile();
downloadLargeObjectToFile();
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void getBucketAcl_lowTpsLongInterval_Netty() {
doGetBucketAcl_lowTpsLongInterval();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ public static void cleanup() {
@Override
protected String getTestBucketName() { return bucketName; }

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void putObject_getObject_highConcurrency() {
putObject();
getObject();
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void largeObject_put_get_usingFile() {
uploadLargeObjectFromFile();
downloadLargeObjectToFile();
}

@RetryableTest(maxRetries = 3, retryableException = StabilityTestsRetryableException.class)
public void getBucketAcl_lowTpsLongInterval_Crt() {
doGetBucketAcl_lowTpsLongInterval();
Expand Down