Skip to content

Adds account ID support for profile credentials provider sources #4340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public final class ProcessCredentialsProvider
private final List<String> executableCommand;
private final Duration credentialRefreshThreshold;
private final long processOutputLimit;
private final String staticAccountId;

private final CachedSupplier<AwsCredentials> processCredentialCache;

Expand Down Expand Up @@ -101,6 +102,7 @@ private ProcessCredentialsProvider(Builder builder) {
this.credentialRefreshThreshold = Validate.isPositive(builder.credentialRefreshThreshold, "expirationBuffer");
this.commandFromBuilder = builder.command;
this.asyncCredentialUpdateEnabled = builder.asyncCredentialUpdateEnabled;
this.staticAccountId = builder.staticAccountId;

CachedSupplier.Builder<AwsCredentials> cacheBuilder = CachedSupplier.builder(this::refreshCredentials)
.cachedValueName(toString());
Expand Down Expand Up @@ -170,19 +172,21 @@ private AwsCredentials credentials(JsonNode credentialsJson) {
Validate.notEmpty(accessKeyId, "AccessKeyId cannot be empty.");
Validate.notEmpty(secretAccessKey, "SecretAccessKey cannot be empty.");

String resolvedAccountId = accountId == null ? this.staticAccountId : accountId;

if (sessionToken != null) {
return AwsSessionCredentials.builder()
.accessKeyId(accessKeyId)
.secretAccessKey(secretAccessKey)
.sessionToken(sessionToken)
.expirationTime(credentialExpirationTime(credentialsJson))
.accountId(accountId)
.accountId(resolvedAccountId)
.build();
}
return AwsBasicCredentials.builder()
.accessKeyId(accessKeyId)
.secretAccessKey(secretAccessKey)
.accountId(accountId)
.accountId(resolvedAccountId)
.build();
}

Expand Down Expand Up @@ -247,6 +251,7 @@ public static class Builder implements CopyableBuilder<Builder, ProcessCredentia
private String command;
private Duration credentialRefreshThreshold = Duration.ofSeconds(15);
private long processOutputLimit = 64000;
private String staticAccountId;

/**
* @see #builder()
Expand All @@ -259,6 +264,7 @@ private Builder(ProcessCredentialsProvider provider) {
this.command = provider.commandFromBuilder;
this.credentialRefreshThreshold = provider.credentialRefreshThreshold;
this.processOutputLimit = provider.processOutputLimit;
this.staticAccountId = provider.staticAccountId;
}

/**
Expand Down Expand Up @@ -304,6 +310,19 @@ public Builder processOutputLimit(long outputByteLimit) {
return this;
}

/**
* Configure a static account id for this credentials provider. Account id for ProcessCredentialsProvider is only
* relevant in a context where a service constructs endpoint URL containing an account id.
* This option should ONLY be used if the provider should return credentials with account id, and the process does not
* output account id. If a static account ID is configured, and the process also returns an account
* id, the process output value overrides the static value. If used, the static account id MUST match the credentials
* returned by the process.
*/
public Builder staticAccountId(String staticAccountId) {
this.staticAccountId = staticAccountId;
return this;
}

public ProcessCredentialsProvider build() {
return new ProcessCredentialsProvider(this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,11 @@ private Optional<AwsCredentialsProvider> credentialsProvider(Set<String> childre
private AwsCredentialsProvider basicProfileCredentialsProvider() {
requireProperties(ProfileProperty.AWS_ACCESS_KEY_ID,
ProfileProperty.AWS_SECRET_ACCESS_KEY);
AwsCredentials credentials = AwsBasicCredentials.create(properties.get(ProfileProperty.AWS_ACCESS_KEY_ID),
properties.get(ProfileProperty.AWS_SECRET_ACCESS_KEY));
AwsCredentials credentials = AwsBasicCredentials.builder()
.accessKeyId(properties.get(ProfileProperty.AWS_ACCESS_KEY_ID))
.secretAccessKey(properties.get(ProfileProperty.AWS_SECRET_ACCESS_KEY))
.accountId(properties.get(ProfileProperty.AWS_ACCOUNT_ID))
.build();
return StaticCredentialsProvider.create(credentials);
}

Expand All @@ -169,9 +172,12 @@ private AwsCredentialsProvider sessionProfileCredentialsProvider() {
requireProperties(ProfileProperty.AWS_ACCESS_KEY_ID,
ProfileProperty.AWS_SECRET_ACCESS_KEY,
ProfileProperty.AWS_SESSION_TOKEN);
AwsCredentials credentials = AwsSessionCredentials.create(properties.get(ProfileProperty.AWS_ACCESS_KEY_ID),
properties.get(ProfileProperty.AWS_SECRET_ACCESS_KEY),
properties.get(ProfileProperty.AWS_SESSION_TOKEN));
AwsCredentials credentials = AwsSessionCredentials.builder()
.accessKeyId(properties.get(ProfileProperty.AWS_ACCESS_KEY_ID))
.secretAccessKey(properties.get(ProfileProperty.AWS_SECRET_ACCESS_KEY))
.sessionToken(properties.get(ProfileProperty.AWS_SESSION_TOKEN))
.accountId(properties.get(ProfileProperty.AWS_ACCOUNT_ID))
.build();
return StaticCredentialsProvider.create(credentials);
}

Expand All @@ -180,6 +186,7 @@ private AwsCredentialsProvider credentialProcessCredentialsProvider() {

return ProcessCredentialsProvider.builder()
.command(properties.get(ProfileProperty.CREDENTIAL_PROCESS))
.staticAccountId(properties.get(ProfileProperty.AWS_ACCOUNT_ID))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.utils.DateUtils;
import software.amazon.awssdk.utils.IoUtils;
import software.amazon.awssdk.utils.Platform;
Expand Down Expand Up @@ -64,36 +69,40 @@ static void teardown() {
throw new IllegalStateException("Failed to delete file: " + errorScriptLocation);
}
}

@Test
void staticCredentialsCanBeLoaded() {
AwsCredentials credentials =
ProcessCredentialsProvider.builder()
.command(String.format("%s accessKeyId secretAccessKey", scriptLocation))
.build()
.resolveCredentials();

assertThat(credentials).isNotInstanceOf(AwsSessionCredentials.class);
assertThat(credentials.accessKeyId()).isEqualTo(ACCESS_KEY_ID);
assertThat(credentials.secretAccessKey()).isEqualTo(SECRET_ACCESS_KEY);
assertThat(credentials.accountId()).isNotPresent();
}

@Test
void staticCredentialsWithAccountIdCanBeLoaded() {
AwsCredentials credentials =
ProcessCredentialsProvider.builder()
.command(String.format("%s %s %s acctid=%s",
scriptLocation, ACCESS_KEY_ID, SECRET_ACCESS_KEY, ACCOUNT_ID))
.build()
.resolveCredentials();
@ParameterizedTest(name = "{index} - {0}")
@MethodSource("staticCredentialsValues")
void staticCredentialsCanBeLoaded(String description, String staticAccountId, Optional<String> expectedValue,
String cmd) {
ProcessCredentialsProvider.Builder providerBuilder = ProcessCredentialsProvider.builder().command(cmd);
if (staticAccountId != null) {
providerBuilder.staticAccountId(staticAccountId);
}
AwsCredentials credentials = providerBuilder.build().resolveCredentials();

verifyCredentials(credentials);
assertThat(credentials).isNotInstanceOf(AwsSessionCredentials.class);
assertThat(credentials.accessKeyId()).isEqualTo(ACCESS_KEY_ID);
assertThat(credentials.secretAccessKey()).isEqualTo(SECRET_ACCESS_KEY);
assertThat(credentials.accountId()).isPresent().isEqualTo(Optional.of(ACCOUNT_ID));

if (expectedValue.isPresent()) {
assertThat(credentials.accountId()).isPresent().hasValue(expectedValue.get());
} else {
assertThat(credentials.accountId()).isNotPresent();
}
}


private static List<Arguments> staticCredentialsValues() {
return Arrays.asList(
Arguments.of("when only containing access key id, secret", null, Optional.empty(),
String.format("%s accessKeyId secretAccessKey", scriptLocation)),
Arguments.of("when output has account id", null, Optional.of(ACCOUNT_ID),
String.format("%s %s %s acctid=%s", scriptLocation, ACCESS_KEY_ID, SECRET_ACCESS_KEY, ACCOUNT_ID)),
Arguments.of("when output has account id, static account id configured", "staticAccountId", Optional.of(ACCOUNT_ID),
String.format("%s %s %s acctid=%s", scriptLocation, ACCESS_KEY_ID, SECRET_ACCESS_KEY, ACCOUNT_ID)),
Arguments.of("when only static account id is configured", "staticAccountId", Optional.of("staticAccountId"),
String.format("%s %s %s", scriptLocation, ACCESS_KEY_ID, SECRET_ACCESS_KEY))
);
}

@Test
void sessionCredentialsCanBeLoaded() {
String expiration = DateUtils.formatIso8601Date(Instant.now());
Expand Down Expand Up @@ -122,21 +131,42 @@ void sessionCredentialsWithAccountIdCanBeLoaded() {

AwsCredentials credentials = credentialsProvider.resolveCredentials();
verifySessionCredentials(credentials, expiration);
assertThat(credentials.accountId()).isPresent().isEqualTo(Optional.of(ACCOUNT_ID));
assertThat(credentials.accountId()).isPresent().hasValue(ACCOUNT_ID);
}

@Test
void sessionCredentialsWithStaticAccountIdCanBeLoaded() {
String expiration = DateUtils.formatIso8601Date(Instant.now());
ProcessCredentialsProvider credentialsProvider =
ProcessCredentialsProvider.builder()
.command(String.format("%s %s %s token=sessionToken exp=%s",
scriptLocation, ACCESS_KEY_ID, SECRET_ACCESS_KEY, expiration))
.credentialRefreshThreshold(Duration.ofSeconds(1))
.staticAccountId("staticAccountId")
.build();

AwsCredentials credentials = credentialsProvider.resolveCredentials();
verifySessionCredentials(credentials, expiration);
assertThat(credentials.accountId()).isPresent().hasValue("staticAccountId");
}

private void verifySessionCredentials(AwsCredentials credentials, String expiration) {
verifyCredentials(credentials);

assertThat(credentials).isInstanceOf(AwsSessionCredentials.class);
AwsSessionCredentials sessionCredentials = (AwsSessionCredentials) credentials;

assertThat(sessionCredentials.accessKeyId()).isEqualTo(ACCESS_KEY_ID);
assertThat(sessionCredentials.secretAccessKey()).isEqualTo(SECRET_ACCESS_KEY);
assertThat(sessionCredentials.sessionToken()).isEqualTo(SESSION_TOKEN);
assertThat(sessionCredentials.expirationTime()).isPresent();
Instant exp = sessionCredentials.expirationTime().get();

assertThat(credentials.expirationTime()).isPresent();
Instant exp = credentials.expirationTime().get();
assertThat(exp).isCloseTo(expiration, within(1, ChronoUnit.MICROS));
}

private void verifyCredentials(AwsCredentials credentials) {
assertThat(credentials.accessKeyId()).isEqualTo(ACCESS_KEY_ID);
assertThat(credentials.secretAccessKey()).isEqualTo(SECRET_ACCESS_KEY);
}

@Test
void resultsAreCached() {
ProcessCredentialsProvider credentialsProvider =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import java.util.function.Supplier;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -114,6 +115,24 @@ void presentProfileReturnsCredentials() {
assertThat(provider.resolveCredentials()).satisfies(credentials -> {
assertThat(credentials.accessKeyId()).isEqualTo("defaultAccessKey");
assertThat(credentials.secretAccessKey()).isEqualTo("defaultSecretAccessKey");
assertThat(credentials.accountId()).isNotPresent();
});
}

@Test
void presentProfileWithAccountIdReturnsCredentialsWithAccountId() {
ProfileFile file = profileFile("[default]\n"
+ "aws_access_key_id = defaultAccessKey\n"
+ "aws_secret_access_key = defaultSecretAccessKey\n"
+ "aws_account_id = defaultAccountId");

ProfileCredentialsProvider provider =
ProfileCredentialsProvider.builder().profileFile(file).profileName("default").build();

assertThat(provider.resolveCredentials()).satisfies(credentials -> {
assertThat(credentials.accessKeyId()).isEqualTo("defaultAccessKey");
assertThat(credentials.secretAccessKey()).isEqualTo("defaultSecretAccessKey");
assertThat(credentials.accountId()).isPresent().isEqualTo(Optional.of("defaultAccountId"));
});
}

Expand Down Expand Up @@ -201,13 +220,15 @@ void resolveCredentials_presentProfileFileSupplier_returnsCredentials() {
assertThat(provider.resolveCredentials()).satisfies(credentials -> {
assertThat(credentials.accessKeyId()).isEqualTo("defaultAccessKey");
assertThat(credentials.secretAccessKey()).isEqualTo("defaultSecretAccessKey");
assertThat(credentials.accountId()).isNotPresent();
});
}

@Test
void resolveCredentials_presentSupplierProfileFile_returnsCredentials() {
Supplier<ProfileFile> supplier = () -> profileFile("[default]\naws_access_key_id = defaultAccessKey\n"
+ "aws_secret_access_key = defaultSecretAccessKey\n");
+ "aws_secret_access_key = defaultSecretAccessKey\n"
+ "aws_account_id = defaultAccountId");

ProfileCredentialsProvider provider =
ProfileCredentialsProvider.builder()
Expand All @@ -218,6 +239,7 @@ void resolveCredentials_presentSupplierProfileFile_returnsCredentials() {
assertThat(provider.resolveCredentials()).satisfies(credentials -> {
assertThat(credentials.accessKeyId()).isEqualTo("defaultAccessKey");
assertThat(credentials.secretAccessKey()).isEqualTo("defaultSecretAccessKey");
assertThat(credentials.accountId()).isPresent();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,52 @@

package software.amazon.awssdk.auth.credentials;

import static org.junit.Assert.assertEquals;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.Test;
import org.junit.jupiter.api.Test;

public class StaticCredentialsProviderTest {

@Test
void getAwsCredentials_ReturnsSameCredentials() {
AwsCredentials credentials = AwsBasicCredentials.create("akid", "skid");
AwsCredentials actualCredentials = StaticCredentialsProvider.create(credentials).resolveCredentials();
assertThat(actualCredentials).isEqualTo(credentials);
}

@Test
public void getAwsCredentials_ReturnsSameCredentials() throws Exception {
final AwsCredentials credentials = AwsBasicCredentials.create("akid", "skid");
final AwsCredentials actualCredentials =
StaticCredentialsProvider.create(credentials).resolveCredentials();
assertEquals(credentials, actualCredentials);
void getAwsCredentialsWithAccountId_ReturnsSameCredentials() {
AwsCredentials credentials = AwsBasicCredentials.builder()
.accessKeyId("akid")
.secretAccessKey("skid")
.accountId("acctid")
.build();
AwsCredentials actualCredentials = StaticCredentialsProvider.create(credentials).resolveCredentials();
assertThat(actualCredentials).isEqualTo(credentials);
}

@Test
public void getSessionAwsCredentials_ReturnsSameCredentials() throws Exception {
final AwsSessionCredentials credentials = AwsSessionCredentials.create("akid", "skid", "token");
final AwsCredentials actualCredentials = StaticCredentialsProvider.create(credentials).resolveCredentials();
assertEquals(credentials, actualCredentials);
void getSessionAwsCredentials_ReturnsSameCredentials() {
AwsSessionCredentials credentials = AwsSessionCredentials.create("akid", "skid", "token");
AwsCredentials actualCredentials = StaticCredentialsProvider.create(credentials).resolveCredentials();
assertThat(actualCredentials).isEqualTo(credentials);
}

@Test(expected = RuntimeException.class)
public void nullCredentials_ThrowsIllegalArgumentException() {
StaticCredentialsProvider.create(null);
@Test
void getSessionAwsCredentialsWithAccountId_ReturnsSameCredentials() {
AwsSessionCredentials credentials = AwsSessionCredentials.builder()
.accessKeyId("akid")
.secretAccessKey("skid")
.sessionToken("token")
.accountId("acctid")
.build();
AwsCredentials actualCredentials = StaticCredentialsProvider.create(credentials).resolveCredentials();
assertThat(actualCredentials).isEqualTo(credentials);
}

@Test
void nullCredentials_ThrowsException() {
assertThatThrownBy(() -> StaticCredentialsProvider.create(null)).isInstanceOf(NullPointerException.class);
}
}
Loading