Skip to content

Commit 8ed3607

Browse files
committed
feat: export default role assumers from sts client
1 parent 012a48b commit 8ed3607

File tree

4 files changed

+114
-1
lines changed

4 files changed

+114
-1
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import { Credentials, Provider } from "@aws-sdk/types";
2+
3+
import { AssumeRoleCommand, AssumeRoleCommandInput } from "./commands/AssumeRoleCommand";
4+
import {
5+
AssumeRoleWithWebIdentityCommand,
6+
AssumeRoleWithWebIdentityCommandInput,
7+
} from "./commands/AssumeRoleWithWebIdentityCommand";
8+
import { STSClient, STSClientConfig } from "./STSClient";
9+
10+
type RoleAssumer = (sourceCreds: Credentials, params: AssumeRoleCommandInput) => Promise<Credentials>;
11+
12+
const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";
13+
14+
// Inject the fallback STS region of us-east-1.
15+
const decorateDefaultRegion = (region: STSClientConfig["region"]): STSClientConfig["region"] => {
16+
if (typeof region !== "function") {
17+
return region === undefined ? ASSUME_ROLE_DEFAULT_REGION : region;
18+
}
19+
return async () => {
20+
try {
21+
return await region();
22+
} catch (e) {
23+
return ASSUME_ROLE_DEFAULT_REGION;
24+
}
25+
};
26+
};
27+
28+
/**
29+
* The default role assumer that used by credential providers when STS.AssumeRole API is needed.
30+
*/
31+
const getDefaultAssumer = (stsOptions: STSClientConfig): RoleAssumer => {
32+
let stsClient: STSClient;
33+
return async (sourceCreds, params) => {
34+
if (!stsClient) {
35+
const { logger } = stsOptions;
36+
stsClient = new STSClient({
37+
logger,
38+
credentials: sourceCreds,
39+
region: decorateDefaultRegion(stsOptions.region),
40+
});
41+
}
42+
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
43+
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
44+
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
45+
}
46+
return {
47+
accessKeyId: Credentials.AccessKeyId,
48+
secretAccessKey: Credentials.SecretAccessKey,
49+
sessionToken: Credentials.SessionToken,
50+
expiration: Credentials.Expiration,
51+
};
52+
};
53+
};
54+
55+
type RoleAssumerWithWebIdentity = (params: AssumeRoleWithWebIdentityCommandInput) => Promise<Credentials>;
56+
57+
/**
58+
* The default role assumer that used by credential providers when STS.AssumeRole API is needed.
59+
*/
60+
const getDefaultAssumerWithWebIdentity = (stsOptions: STSClientConfig): RoleAssumerWithWebIdentity => {
61+
let stsClient: STSClient;
62+
return async (params) => {
63+
if (!stsClient) {
64+
const { logger } = stsOptions;
65+
stsClient = new STSClient({
66+
logger,
67+
region: decorateDefaultRegion(stsOptions.region),
68+
});
69+
}
70+
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
71+
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
72+
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
73+
}
74+
return {
75+
accessKeyId: Credentials.AccessKeyId,
76+
secretAccessKey: Credentials.SecretAccessKey,
77+
sessionToken: Credentials.SessionToken,
78+
expiration: Credentials.Expiration,
79+
};
80+
};
81+
};
82+
83+
type DefaultCredentialProvider = (input: any) => Provider<Credentials>;
84+
/**
85+
* The default credential providers requires STS client to assume role with desired API: sts:assumeRole,
86+
* sts:assumeRoleWithWebIdentity, etc. This function decorates the default credential provider with role assumers which
87+
* encapsulates the process of calling STS commands. This can only be imported by AWS client packages to avoid circular
88+
* dependencies.
89+
*
90+
* @internal
91+
*/
92+
export const decorateDefaultCredentialProvider = (provider: DefaultCredentialProvider): DefaultCredentialProvider => (
93+
input: any
94+
) =>
95+
provider({
96+
roleAssumer: getDefaultAssumer(input),
97+
roleAssumerWithWebIdentity: getDefaultAssumerWithWebIdentity(input),
98+
...input,
99+
});

clients/client-sts/package.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,11 @@
8989
"type": "git",
9090
"url": "https://github.com/aws/aws-sdk-js-v3.git",
9191
"directory": "clients/client-sts"
92+
},
93+
"exports": {
94+
"./defaultRoleAssumers": {
95+
"require": "./dist/cjs/defaultRoleAssumers.js",
96+
"import": "./dist/es/defaultRoleAssumers.js"
97+
}
9298
}
9399
}

clients/client-sts/runtimeConfig.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ export const ClientDefaultValues: Required<ClientDefaults> = {
2222
base64Decoder: fromBase64,
2323
base64Encoder: toBase64,
2424
bodyLengthChecker: calculateBodyLength,
25-
credentialDefaultProvider,
25+
credentialDefaultProvider: (input) => {
26+
/**
27+
* Inline require to avoid circular dependencies
28+
*/
29+
return require("./defaultRoleAssumers").decorateDefaultCredentialProvider(credentialDefaultProvider)(input);
30+
},
2631
defaultUserAgentProvider: defaultUserAgent({
2732
serviceId: ClientSharedValues.serviceId,
2833
clientVersion: packageInfo.version,

scripts/generate-clients/copy-to-clients.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ const copyToClients = async (sourceDir, destinationDir) => {
135135
directory: `clients/${clientName}`,
136136
},
137137
};
138+
if (clientName === "client-sts" && destManifest.exports) {
139+
mergedManifest["exports"] = destManifest.exports;
140+
}
138141
writeFileSync(destSubPath, JSON.stringify(mergedManifest, null, 2).concat(`\n`));
139142
} else if (overwritablePredicate(packageSub) || !existsSync(destSubPath)) {
140143
if (lstatSync(packageSubPath).isDirectory()) removeSync(destSubPath);

0 commit comments

Comments
 (0)