Skip to content

Commit b0a3c24

Browse files
authored
Adds Policy and Duration parameters to stscreds.WebIdentityRoleOptions (#1670)
Adds the Policy and Duration parameters from sts.AssumeRoleWithWebIdentityInput to stscreds.WebIdentityRoleOptions. Closes #1662
1 parent da333d3 commit b0a3c24

File tree

4 files changed

+120
-25
lines changed

4 files changed

+120
-25
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "1be705cb-9be9-4061-bd94-a4dbded36da5",
3+
"type": "feature",
4+
"description": "Adds Duration and Policy options that can be used when creating stscreds.WebIdentityRoleProvider credentials provider.",
5+
"modules": [
6+
"credentials"
7+
]
8+
}

credentials/stscreds/assume_role_provider.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,13 @@ type AssumeRoleAPIClient interface {
136136
AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error)
137137
}
138138

139-
// DefaultDuration is the default amount of time in minutes that the credentials
140-
// will be valid for.
139+
// DefaultDuration is the default amount of time in minutes that the
140+
// credentials will be valid for. This value is only used by AssumeRoleProvider
141+
// for specifying the default expiry duration of an assume role.
142+
//
143+
// Other providers such as WebIdentityRoleProvider do not use this value, and
144+
// instead rely on STS API's default parameter handing to assign a default
145+
// value.
141146
var DefaultDuration = time.Duration(15) * time.Minute
142147

143148
// AssumeRoleProvider retrieves temporary credentials from the STS service, and

credentials/stscreds/web_identity_provider.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io/ioutil"
77
"strconv"
8+
"time"
89

910
"github.com/aws/aws-sdk-go-v2/aws"
1011
"github.com/aws/aws-sdk-go-v2/aws/retry"
@@ -45,6 +46,19 @@ type WebIdentityRoleOptions struct {
4546
// Session name, if you wish to uniquely identify this session.
4647
RoleSessionName string
4748

49+
// Expiry duration of the STS credentials. STS will assign a default expiry
50+
// duration if this value is unset. This is different from the Duration
51+
// option of AssumeRoleProvider, which automatically assigns 15 minutes if
52+
// Duration is unset.
53+
//
54+
// See the STS AssumeRoleWithWebIdentity API reference guide for more
55+
// information on defaults.
56+
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
57+
Duration time.Duration
58+
59+
// An IAM policy in JSON format that you want to use as an inline session policy.
60+
Policy *string
61+
4862
// The Amazon Resource Names (ARNs) of the IAM managed policies that you
4963
// want to use as managed session policies. The policies must exist in the
5064
// same account as the role.
@@ -100,12 +114,21 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
100114
// uses unix time in nanoseconds to uniquely identify sessions.
101115
sessionName = strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
102116
}
103-
resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{
117+
input := &sts.AssumeRoleWithWebIdentityInput{
104118
PolicyArns: p.options.PolicyARNs,
105119
RoleArn: &p.options.RoleARN,
106120
RoleSessionName: &sessionName,
107121
WebIdentityToken: aws.String(string(b)),
108-
}, func(options *sts.Options) {
122+
}
123+
if p.options.Duration != 0 {
124+
// If set use the value, otherwise STS will assign a default expiration duration.
125+
input.DurationSeconds = aws.Int32(int32(p.options.Duration / time.Second))
126+
}
127+
if p.options.Policy != nil {
128+
input.Policy = p.options.Policy
129+
}
130+
131+
resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, input, func(options *sts.Options) {
109132
options.Retryer = retry.AddWithErrorCodes(options.Retryer, invalidIdentityTokenExceptionCode)
110133
})
111134
if err != nil {

credentials/stscreds/web_identity_provider_test.go

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,32 +31,79 @@ func (m mockErrorCode) Error() string {
3131
}
3232

3333
func TestWebIdentityProviderRetrieve(t *testing.T) {
34-
defer func() func() {
35-
o := sdk.NowTime
36-
sdk.NowTime = func() time.Time {
37-
return time.Time{}
38-
}
39-
return func() {
40-
sdk.NowTime = o
41-
}
42-
}()()
34+
restorTime := sdk.TestingUseReferenceTime(time.Time{})
35+
defer restorTime()
4336

4437
cases := map[string]struct {
4538
mockClient mockAssumeRoleWithWebIdentity
4639
roleARN string
4740
tokenFilepath string
4841
sessionName string
49-
expectedError error
42+
options func(*stscreds.WebIdentityRoleOptions)
5043
expectedCredValue aws.Credentials
5144
}{
52-
"session name case": {
45+
"success": {
5346
roleARN: "arn01234567890123456789",
5447
tokenFilepath: "testdata/token.jwt",
55-
sessionName: "foo",
56-
mockClient: func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
48+
options: func(o *stscreds.WebIdentityRoleOptions) {
49+
o.RoleSessionName = "foo"
50+
},
51+
mockClient: func(
52+
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
53+
) (
54+
*sts.AssumeRoleWithWebIdentityOutput, error,
55+
) {
56+
if e, a := "foo", *params.RoleSessionName; e != a {
57+
return nil, fmt.Errorf("expected %v, but received %v", e, a)
58+
}
59+
if params.DurationSeconds != nil {
60+
return nil, fmt.Errorf("expect no duration seconds, got %v",
61+
*params.DurationSeconds)
62+
}
63+
if params.Policy != nil {
64+
return nil, fmt.Errorf("expect no policy, got %v",
65+
*params.Policy)
66+
}
67+
return &sts.AssumeRoleWithWebIdentityOutput{
68+
Credentials: &types.Credentials{
69+
Expiration: aws.Time(sdk.NowTime()),
70+
AccessKeyId: aws.String("access-key-id"),
71+
SecretAccessKey: aws.String("secret-access-key"),
72+
SessionToken: aws.String("session-token"),
73+
},
74+
}, nil
75+
},
76+
expectedCredValue: aws.Credentials{
77+
AccessKeyID: "access-key-id",
78+
SecretAccessKey: "secret-access-key",
79+
SessionToken: "session-token",
80+
Source: stscreds.WebIdentityProviderName,
81+
CanExpire: true,
82+
Expires: sdk.NowTime(),
83+
},
84+
},
85+
"success with duration and policy": {
86+
roleARN: "arn01234567890123456789",
87+
tokenFilepath: "testdata/token.jwt",
88+
options: func(o *stscreds.WebIdentityRoleOptions) {
89+
o.Duration = 42 * time.Second
90+
o.Policy = aws.String("super secret policy")
91+
o.RoleSessionName = "foo"
92+
},
93+
mockClient: func(
94+
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
95+
) (
96+
*sts.AssumeRoleWithWebIdentityOutput, error,
97+
) {
5798
if e, a := "foo", *params.RoleSessionName; e != a {
5899
return nil, fmt.Errorf("expected %v, but received %v", e, a)
59100
}
101+
if e, a := int32(42), aws.ToInt32(params.DurationSeconds); e != a {
102+
return nil, fmt.Errorf("expect %v duration seconds, got %v", e, a)
103+
}
104+
if e, a := "super secret policy", aws.ToString(params.Policy); e != a {
105+
return nil, fmt.Errorf("expect %v policy, got %v", e, a)
106+
}
60107
return &sts.AssumeRoleWithWebIdentityOutput{
61108
Credentials: &types.Credentials{
62109
Expiration: aws.Time(sdk.NowTime()),
@@ -78,8 +125,14 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
78125
"configures token retry": {
79126
roleARN: "arn01234567890123456789",
80127
tokenFilepath: "testdata/token.jwt",
81-
sessionName: "foo",
82-
mockClient: func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
128+
options: func(o *stscreds.WebIdentityRoleOptions) {
129+
o.RoleSessionName = "foo"
130+
},
131+
mockClient: func(
132+
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
133+
) (
134+
*sts.AssumeRoleWithWebIdentityOutput, error,
135+
) {
83136
o := sts.Options{}
84137
for _, fn := range optFns {
85138
fn(&o)
@@ -112,13 +165,19 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
112165

113166
for name, c := range cases {
114167
t.Run(name, func(t *testing.T) {
115-
p := stscreds.NewWebIdentityRoleProvider(c.mockClient, c.roleARN, stscreds.IdentityTokenFile(c.tokenFilepath),
116-
func(o *stscreds.WebIdentityRoleOptions) {
117-
o.RoleSessionName = c.sessionName
118-
})
168+
var optFns []func(*stscreds.WebIdentityRoleOptions)
169+
if c.options != nil {
170+
optFns = append(optFns, c.options)
171+
}
172+
p := stscreds.NewWebIdentityRoleProvider(
173+
c.mockClient,
174+
c.roleARN,
175+
stscreds.IdentityTokenFile(c.tokenFilepath),
176+
optFns...,
177+
)
119178
credValue, err := p.Retrieve(context.Background())
120-
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
121-
t.Errorf("expected %v, but received %v", e, a)
179+
if err != nil {
180+
t.Fatalf("expect no error, got %v", err)
122181
}
123182

124183
if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {

0 commit comments

Comments
 (0)