Skip to content

Updates StsCredentialsProvider to explicitly handle AwsSessionCredentials #4067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS STS",
"contributor": "",
"description": "Updates the core STS credential provider logic to return AwsSessionCredentials instead of an STS-specific class, and adds expirationTime to AwsSessionCredentials"
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

package software.amazon.awssdk.auth.credentials;

import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import software.amazon.awssdk.annotations.Immutable;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.utils.ToString;
Expand All @@ -34,10 +36,20 @@ public final class AwsSessionCredentials implements AwsCredentials {
private final String secretAccessKey;
private final String sessionToken;

private AwsSessionCredentials(String accessKey, String secretKey, String sessionToken) {
this.accessKeyId = Validate.paramNotNull(accessKey, "accessKey");
this.secretAccessKey = Validate.paramNotNull(secretKey, "secretKey");
this.sessionToken = Validate.paramNotNull(sessionToken, "sessionToken");
private final Instant expirationTime;

private AwsSessionCredentials(Builder builder) {
this.accessKeyId = Validate.paramNotNull(builder.accessKeyId, "accessKey");
this.secretAccessKey = Validate.paramNotNull(builder.secretAccessKey, "secretKey");
this.sessionToken = Validate.paramNotNull(builder.sessionToken, "sessionToken");
this.expirationTime = builder.expirationTime;
}

/**
* Returns a builder for this object.
*/
public static Builder builder() {
return new Builder();
}

/**
Expand All @@ -49,7 +61,7 @@ private AwsSessionCredentials(String accessKey, String secretKey, String session
* received temporary permission to access some resource.
*/
public static AwsSessionCredentials create(String accessKey, String secretKey, String sessionToken) {
return new AwsSessionCredentials(accessKey, secretKey, sessionToken);
return builder().accessKeyId(accessKey).secretAccessKey(secretKey).sessionToken(sessionToken).build();
}

/**
Expand All @@ -68,6 +80,13 @@ public String secretAccessKey() {
return secretAccessKey;
}

/**
* Retrieve the expiration time of these credentials, if it exists.
*/
public Optional<Instant> expirationTime() {
return Optional.ofNullable(expirationTime);
}

/**
* Retrieve the AWS session token. This token is retrieved from an AWS token service, and is used for authenticating that this
* user has received temporary permission to access some resource.
Expand Down Expand Up @@ -95,7 +114,8 @@ public boolean equals(Object o) {
AwsSessionCredentials that = (AwsSessionCredentials) o;
return Objects.equals(accessKeyId, that.accessKeyId) &&
Objects.equals(secretAccessKey, that.secretAccessKey) &&
Objects.equals(sessionToken, that.sessionToken);
Objects.equals(sessionToken, that.sessionToken) &&
Objects.equals(expirationTime, that.expirationTime().orElse(null));
}

@Override
Expand All @@ -104,6 +124,57 @@ public int hashCode() {
hashCode = 31 * hashCode + Objects.hashCode(accessKeyId());
hashCode = 31 * hashCode + Objects.hashCode(secretAccessKey());
hashCode = 31 * hashCode + Objects.hashCode(sessionToken());
hashCode = 31 * hashCode + Objects.hashCode(expirationTime);
return hashCode;
}

/**
* A builder for creating an instance of {@link AwsSessionCredentials}. This can be created with the static
* {@link #builder()} method.
*/
public static final class Builder {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we javadoc this?

private String accessKeyId;
private String secretAccessKey;
private String sessionToken;
private Instant expirationTime;

/**
* The AWS access key, used to identify the user interacting with services. Required.
*/
public Builder accessKeyId(String accessKeyId) {
this.accessKeyId = accessKeyId;
return this;
}

/**
* The AWS secret access key, used to authenticate the user interacting with services. Required
*/
public Builder secretAccessKey(String secretAccessKey) {
this.secretAccessKey = secretAccessKey;
return this;
}

/**
* The AWS session token, retrieved from an AWS token service, used for authenticating that this user has
* received temporary permission to access some resource. Required
*/
public Builder sessionToken(String sessionToken) {
this.sessionToken = sessionToken;
return this;
}

/**
* The time after which this identity will no longer be valid. If this is empty,
* an expiration time is not known (but the identity may still expire at some
* time in the future).
*/
public Builder expirationTime(Instant expirationTime) {
this.expirationTime = expirationTime;
return this;
}

public AwsSessionCredentials build() {
return new AwsSessionCredentials(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,64 @@

package software.amazon.awssdk.auth.credentials.internal;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.assertj.core.api.Assertions.assertThat;

import nl.jqno.equalsverifier.EqualsVerifier;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;

public class AwsSessionCredentialsTest {

private static final String ACCESS_KEY_ID = "accessKeyId";
private static final String SECRET_ACCESS_KEY = "secretAccessKey";
private static final String SESSION_TOKEN = "sessionToken";

public void equalsHashcode() {
EqualsVerifier.forClass(AwsSessionCredentials.class)
.verify();
}

@Test
public void emptyBuilder_ThrowsException() {
assertThrows(NullPointerException.class, () -> AwsSessionCredentials.builder().build());
}

@Test
public void builderMissingSessionToken_ThrowsException() {
assertThrows(NullPointerException.class, () -> AwsSessionCredentials.builder()
.accessKeyId(ACCESS_KEY_ID)
.secretAccessKey(SECRET_ACCESS_KEY)
.build());
}

@Test
public void equalsHashCode() {
AwsSessionCredentials credentials =
AwsSessionCredentials.create("test", "key", "sessionToken");

AwsSessionCredentials anotherCredentials =
AwsSessionCredentials.create("test", "key", "sessionToken");
assertThat(credentials).isEqualTo(anotherCredentials);
assertThat(credentials.hashCode()).isEqualTo(anotherCredentials.hashCode());
public void builderMissingAccessKeyId_ThrowsException() {
assertThrows(NullPointerException.class, () -> AwsSessionCredentials.builder()
.secretAccessKey(SECRET_ACCESS_KEY)
.sessionToken(SESSION_TOKEN)
.build());
}

@Test
public void create_isSuccessful() {
AwsSessionCredentials identity = AwsSessionCredentials.create(ACCESS_KEY_ID,
SECRET_ACCESS_KEY,
SESSION_TOKEN);
assertEquals(ACCESS_KEY_ID, identity.accessKeyId());
assertEquals(SECRET_ACCESS_KEY, identity.secretAccessKey());
assertEquals(SESSION_TOKEN, identity.sessionToken());
}

@Test
public void build_isSuccessful() {
AwsSessionCredentials identity = AwsSessionCredentials.builder()
.accessKeyId(ACCESS_KEY_ID)
.secretAccessKey(SECRET_ACCESS_KEY)
.sessionToken(SESSION_TOKEN)
.build();
assertEquals(ACCESS_KEY_ID, identity.accessKeyId());
assertEquals(SECRET_ACCESS_KEY, identity.secretAccessKey());
assertEquals(SESSION_TOKEN, identity.sessionToken());
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

package software.amazon.awssdk.services.sts.auth;

import static software.amazon.awssdk.services.sts.internal.StsAuthUtils.toAwsSessionCredentials;

import java.util.function.Consumer;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.annotations.ThreadSafe;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.utils.ToString;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.builder.ToCopyableBuilder;
Expand Down Expand Up @@ -65,10 +67,10 @@ public static Builder builder() {
}

@Override
protected Credentials getUpdatedCredentials(StsClient stsClient) {
protected AwsSessionCredentials getUpdatedCredentials(StsClient stsClient) {
AssumeRoleRequest assumeRoleRequest = assumeRoleRequestSupplier.get();
Validate.notNull(assumeRoleRequest, "Assume role request must not be null.");
return stsClient.assumeRole(assumeRoleRequest).credentials();
return toAwsSessionCredentials(stsClient.assumeRole(assumeRoleRequest).credentials());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

package software.amazon.awssdk.services.sts.auth;

import static software.amazon.awssdk.services.sts.internal.StsAuthUtils.toAwsSessionCredentials;

import java.util.function.Consumer;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.annotations.ThreadSafe;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithSamlRequest;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.utils.ToString;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.builder.ToCopyableBuilder;
Expand Down Expand Up @@ -66,10 +68,10 @@ public static Builder builder() {
}

@Override
protected Credentials getUpdatedCredentials(StsClient stsClient) {
protected AwsSessionCredentials getUpdatedCredentials(StsClient stsClient) {
AssumeRoleWithSamlRequest assumeRoleWithSamlRequest = assumeRoleWithSamlRequestSupplier.get();
Validate.notNull(assumeRoleWithSamlRequest, "Assume role with saml request must not be null.");
return stsClient.assumeRoleWithSAML(assumeRoleWithSamlRequest).credentials();
return toAwsSessionCredentials(stsClient.assumeRoleWithSAML(assumeRoleWithSamlRequest).credentials());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.services.sts.auth;

import static software.amazon.awssdk.services.sts.internal.StsAuthUtils.toAwsSessionCredentials;
import static software.amazon.awssdk.utils.Validate.notNull;

import java.util.function.Consumer;
Expand All @@ -23,9 +24,9 @@
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.annotations.ThreadSafe;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityRequest;
import software.amazon.awssdk.services.sts.model.Credentials;
import software.amazon.awssdk.utils.ToString;
import software.amazon.awssdk.utils.builder.ToCopyableBuilder;

Expand Down Expand Up @@ -67,10 +68,10 @@ public static Builder builder() {
}

@Override
protected Credentials getUpdatedCredentials(StsClient stsClient) {
protected AwsSessionCredentials getUpdatedCredentials(StsClient stsClient) {
AssumeRoleWithWebIdentityRequest request = assumeRoleWithWebIdentityRequest.get();
notNull(request, "AssumeRoleWithWebIdentityRequest can't be null");
return stsClient.assumeRoleWithWebIdentity(request).credentials();
return toAwsSessionCredentials(stsClient.assumeRoleWithWebIdentity(request).credentials());
}

@Override
Expand Down
Loading