Skip to content

feat: implement BufferedCipherSubscriber to enforce buffered decrypti… #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package software.amazon.encryption.s3.internal;

import org.reactivestreams.Subscriber;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.encryption.s3.legacy.internal.RangedGetUtils;

import javax.crypto.Cipher;
import java.nio.ByteBuffer;

public class BufferedCipherPublisher implements SdkPublisher<ByteBuffer> {

private final SdkPublisher<ByteBuffer> wrappedPublisher;
private final Cipher cipher;
private final Long contentLength;
private final long[] range;
private final String contentRange;
private final int cipherTagLengthBits;

public BufferedCipherPublisher(final Cipher cipher, final SdkPublisher<ByteBuffer> wrappedPublisher, final Long contentLength, long[] range, String contentRange, int cipherTagLengthBits) {
this.wrappedPublisher = wrappedPublisher;
this.cipher = cipher;
this.contentLength = contentLength;
this.range = range;
this.contentRange = contentRange;
this.cipherTagLengthBits = cipherTagLengthBits;
}

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
// Wrap the (customer) subscriber in a CipherSubscriber, then subscribe it
// to the wrapped (ciphertext) publisher
Subscriber<? super ByteBuffer> wrappedSubscriber = RangedGetUtils.adjustToDesiredRange(subscriber, range, contentRange, cipherTagLengthBits);
wrappedPublisher.subscribe(new BufferedCipherSubscriber(wrappedSubscriber, cipher, contentLength));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package software.amazon.encryption.s3.internal;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.encryption.s3.S3EncryptionClientException;
import software.amazon.encryption.s3.S3EncryptionClientSecurityException;

import javax.crypto.Cipher;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/**
* A subscriber which decrypts data by buffering the object's contents
* so that authentication can be done before any plaintext is released.
* This prevents "release of unauthenticated plaintext" at the cost of
* allocating a large buffer.
*/
public class BufferedCipherSubscriber implements Subscriber<ByteBuffer> {

// 64MiB ought to be enough for most usecases
private static final long BUFFERED_MAX_CONTENT_LENGTH_MiB = 64;
private static final long BUFFERED_MAX_CONTENT_LENGTH_BYTES = 1024 * 1024 * BUFFERED_MAX_CONTENT_LENGTH_MiB;

private final AtomicInteger contentRead = new AtomicInteger(0);
private final AtomicBoolean doneFinal = new AtomicBoolean(false);
private final Subscriber<? super ByteBuffer> wrappedSubscriber;
private final Cipher cipher;
private final int contentLength;

private byte[] outputBuffer;
private final Queue<ByteBuffer> buffers = new ConcurrentLinkedQueue<>();

BufferedCipherSubscriber(Subscriber<? super ByteBuffer> wrappedSubscriber, Cipher cipher, Long contentLength) {
this.wrappedSubscriber = wrappedSubscriber;
this.cipher = cipher;
if (contentLength == null) {
throw new S3EncryptionClientException("contentLength cannot be null in buffered mode. To enable unbounded " +
"streaming, reconfigure the S3 Encryption Client with Delayed Authentication mode enabled.");
}
if (contentLength > BUFFERED_MAX_CONTENT_LENGTH_BYTES) {
throw new S3EncryptionClientException(String.format("The object you are attempting to decrypt exceeds the maximum content " +
"length allowed in default mode. Please enable Delayed Authentication mode to decrypt objects larger" +
"than %d", BUFFERED_MAX_CONTENT_LENGTH_MiB));
}
this.contentLength = Math.toIntExact(contentLength);
}

@Override
public void onSubscribe(Subscription s) {
wrappedSubscriber.onSubscribe(s);
}

@Override
public void onNext(ByteBuffer byteBuffer) {
int amountToReadFromByteBuffer = getAmountToReadFromByteBuffer(byteBuffer);

if (amountToReadFromByteBuffer > 0) {
byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer, amountToReadFromByteBuffer);
outputBuffer = cipher.update(buf, 0, amountToReadFromByteBuffer);

if (outputBuffer == null && amountToReadFromByteBuffer < cipher.getBlockSize()) {
// The underlying data is too short to fill in the block cipher
// This is true at the end of the file, so complete to get the final
// bytes
this.onComplete();
}

// Enqueue the buffer until all data is read
buffers.add(ByteBuffer.wrap(outputBuffer));

// Sometimes, onComplete won't be called, so we check if all
// data is read to avoid hanging indefinitely
if (contentRead.get() == contentLength) {
this.onComplete();
}
// This avoids the subscriber waiting indefinitely for more data
// without actually releasing any plaintext before it can be authenticated
wrappedSubscriber.onNext(ByteBuffer.allocate(0));
}

}

private int getAmountToReadFromByteBuffer(ByteBuffer byteBuffer) {

long amountReadSoFar = contentRead.getAndAdd(byteBuffer.remaining());
long amountRemaining = Math.max(0, contentLength - amountReadSoFar);

if (amountRemaining > byteBuffer.remaining()) {
return byteBuffer.remaining();
} else {
return Math.toIntExact(amountRemaining);
}
}

@Override
public void onError(Throwable t) {
wrappedSubscriber.onError(t);
}

@Override
public void onComplete() {
if (doneFinal.get()) {
// doFinal has already been called, bail out
return;
}
try {
outputBuffer = cipher.doFinal();
doneFinal.set(true);
// Once doFinal is called, then we can release the plaintext
if (contentRead.get() == contentLength) {
while (!buffers.isEmpty()) {
wrappedSubscriber.onNext(buffers.remove());
}
}
// Send the final bytes to the wrapped subscriber
wrappedSubscriber.onNext(ByteBuffer.wrap(outputBuffer));
} catch (final GeneralSecurityException exception) {
// Forward error, else the wrapped subscriber waits indefinitely
wrappedSubscriber.onError(exception);
throw new S3EncryptionClientSecurityException(exception.getMessage(), exception);
}
wrappedSubscriber.onComplete();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,19 @@ public void onStream(SdkPublisher<ByteBuffer> ciphertextPublisher) {
throw new S3EncryptionClientException("Unknown algorithm: " + algorithmSuite.cipherName());
}

CipherPublisher plaintextPublisher = new CipherPublisher(cipher, ciphertextPublisher, getObjectResponse.contentLength(), desiredRange, contentMetadata.contentRange(), algorithmSuite.cipherTagLengthBits());
wrappedAsyncResponseTransformer.onStream(plaintextPublisher);
if (algorithmSuite.equals(AlgorithmSuite.ALG_AES_256_CBC_IV16_NO_KDF) || _enableDelayedAuthentication) {
// CBC and GCM with delayed auth enabled use a standard publisher
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Praise: I love clarifying comments.

CipherPublisher plaintextPublisher = new CipherPublisher(cipher, ciphertextPublisher,
getObjectResponse.contentLength(), desiredRange, contentMetadata.contentRange(), algorithmSuite.cipherTagLengthBits());
wrappedAsyncResponseTransformer.onStream(plaintextPublisher);
} else {
// Use buffered publisher for GCM when delayed auth is not enabled
BufferedCipherPublisher plaintextPublisher = new BufferedCipherPublisher(cipher, ciphertextPublisher,
getObjectResponse.contentLength(), desiredRange, contentMetadata.contentRange(), algorithmSuite.cipherTagLengthBits());
wrappedAsyncResponseTransformer.onStream(plaintextPublisher);

}

} catch (GeneralSecurityException e) {
throw new S3EncryptionClientException("Unable to " + algorithmSuite.cipherName() + " content decrypt.", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -21,6 +22,7 @@
import software.amazon.encryption.s3.utils.MarkResetBoundedZerosInputStream;
import software.amazon.encryption.s3.utils.S3EncryptionClientTestResources;

import javax.crypto.AEADBadTagException;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import java.io.IOException;
Expand All @@ -31,7 +33,10 @@
import java.security.Security;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static software.amazon.encryption.s3.utils.S3EncryptionClientTestResources.appendTestSuffix;
import static software.amazon.encryption.s3.utils.S3EncryptionClientTestResources.deleteObject;

Expand All @@ -55,7 +60,6 @@ public static void setUp() throws NoSuchAlgorithmException {
@Test
public void markResetInputStreamV3Encrypt() throws IOException {
final String objectKey = appendTestSuffix("markResetInputStreamV3Encrypt");

// V3 Client
S3Client v3Client = S3EncryptionClient.builder()
.aesKey(AES_KEY)
Expand Down Expand Up @@ -196,8 +200,7 @@ public void ordinaryInputStreamV3DecryptCbc() throws IOException {
v3Client.close();
}

// TODO : Add Delayed Authentication to Async Client.
//@Test
@Test
public void delayedAuthModeWithLargeObject() throws IOException {
final String objectKey = appendTestSuffix("large-object-test");

Expand All @@ -219,12 +222,11 @@ public void delayedAuthModeWithLargeObject() throws IOException {
.build(), RequestBody.fromInputStream(largeObjectStream, fileSizeExceedingDefaultLimit));

largeObjectStream.close();
// TODO : Add Delayed Authentication to Async Client.

// // Delayed Authentication is not enabled, so getObject fails
// assertThrows(S3EncryptionClientException.class, () -> v3Client.getObjectAsBytes(builder -> builder
// .bucket(BUCKET)
// .key(objectKey)));
// Delayed Authentication is not enabled, so getObject fails
assertThrows(S3EncryptionClientException.class, () -> v3Client.getObjectAsBytes(builder -> builder
.bucket(BUCKET)
.key(objectKey)));

S3Client v3ClientWithDelayedAuth = S3EncryptionClient.builder()
.aesKey(AES_KEY)
Expand Down Expand Up @@ -263,4 +265,78 @@ public void delayedAuthModeWithLargerThanMaxObjectFails() throws IOException {
// Cleanup
v3Client.close();
}

@Test
public void AesGcmV3toV3StreamWithTamperedTag() {
final String objectKey = "aes-gcm-v3-to-v3-stream-tamper-tag";

// V3 Client
S3Client v3Client = S3EncryptionClient.builder()
.aesKey(AES_KEY)
.build();

// 640 bytes of gibberish - enough to cover multiple blocks
final String input = "1esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "2esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "3esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "4esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "5esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "6esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "7esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "8esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "9esAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo"
+ "10sAAAYoAesAAAEndOfChunkAesAAAYoAesAAAYoAesAAAYoAesAAAYoAesAAAYo";
final int inputLength = input.length();
v3Client.putObject(PutObjectRequest.builder()
.bucket(BUCKET)
.key(objectKey)
.build(), RequestBody.fromString(input));

// Use an unencrypted (plaintext) client to interact with the encrypted object
final S3Client plaintextS3Client = S3Client.builder().build();
ResponseBytes<GetObjectResponse> objectResponse = plaintextS3Client.getObjectAsBytes(builder -> builder
.bucket(BUCKET)
.key(objectKey));
final byte[] encryptedBytes = objectResponse.asByteArray();
final int tagLength = 16;
final byte[] tamperedBytes = new byte[inputLength + tagLength];
// Copy the enciphered bytes
System.arraycopy(encryptedBytes, 0, tamperedBytes, 0, inputLength);
final byte[] tamperedTag = new byte[tagLength];
// Increment the first byte of the tag
tamperedTag[0] = (byte) (encryptedBytes[inputLength + 1] + 1);
// Copy the rest of the tag as-is
System.arraycopy(encryptedBytes, inputLength + 1, tamperedTag, 1, tagLength - 1);
// Append the tampered tag
System.arraycopy(tamperedTag, 0, tamperedBytes, inputLength, tagLength);

// Sanity check that the objects differ
assertNotEquals(encryptedBytes, tamperedBytes);

// Replace the encrypted object with the tampered object
PutObjectRequest tamperedPut = PutObjectRequest.builder()
.bucket(BUCKET)
.key(objectKey)
.metadata(objectResponse.response().metadata()) // Preserve metadata from encrypted object
.build();
plaintextS3Client.putObject(tamperedPut, RequestBody.fromBytes(tamperedBytes));

// Get (and decrypt) the (modified) object from S3
ResponseInputStream<GetObjectResponse> dataStream = v3Client.getObject(builder -> builder
.bucket(BUCKET)
.key(objectKey));

final int chunkSize = 300;
final byte[] chunk1 = new byte[chunkSize];

// Stream decryption will throw an exception on the first byte read
try {
dataStream.read(chunk1, 0, chunkSize);
} catch (RuntimeException outerEx) {
assertTrue(outerEx.getCause() instanceof AEADBadTagException);
} catch (IOException unexpected) {
// Not expected, but fail the test anyway
fail(unexpected);
}
}
}