Skip to content

Commit 228a81d

Browse files
authored
change: use regional endpoint when creating AWS STS client (#1026)
1 parent ded3c8f commit 228a81d

File tree

4 files changed

+48
-10
lines changed

4 files changed

+48
-10
lines changed

src/sagemaker/session.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
name_from_image,
3838
secondary_training_status_changed,
3939
secondary_training_status_message,
40+
sts_regional_endpoint,
4041
)
4142
from sagemaker import exceptions
4243

@@ -1377,10 +1378,13 @@ def expand_role(self, role):
13771378

13781379
def get_caller_identity_arn(self):
13791380
"""Returns the ARN user or role whose credentials are used to call the API.
1381+
13801382
Returns:
1381-
(str): The ARN user or role
1383+
str: The ARN user or role
13821384
"""
1383-
assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]
1385+
assumed_role = self.boto_session.client(
1386+
"sts", endpoint_url=sts_regional_endpoint(self.boto_region_name)
1387+
).get_caller_identity()["Arn"]
13841388

13851389
if "AmazonSageMaker-ExecutionRole" in assumed_role:
13861390
role = re.sub(

src/sagemaker/utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -538,6 +538,24 @@ def get_ecr_image_uri_prefix(account, region):
538538
return "{}.dkr.ecr.{}.{}".format(account, region, domain)
539539

540540

541+
def sts_regional_endpoint(region):
542+
"""Get the AWS STS endpoint specific for the given region.
543+
544+
We need this function because the AWS SDK does not yet honor
545+
the ``region_name`` parameter when creating an AWS STS client.
546+
547+
For the list of regional endpoints, see
548+
https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_region-endpoints.
549+
550+
Args:
551+
region (str): AWS region name
552+
553+
Returns:
554+
str: AWS STS regional endpoint
555+
"""
556+
return "sts.{}.amazonaws.com".format(region)
557+
558+
541559
class DeferredError(object):
542560
"""Stores an exception and raises it at a later time if this object is
543561
accessed in any way. Useful to allow soft-dependencies on imports, so that

tests/unit/test_session.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -31,6 +31,7 @@
3131
SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}]
3232

3333
REGION = "us-west-2"
34+
STS_ENDPOINT = "sts.us-west-2.amazonaws.com"
3435

3536

3637
@pytest.fixture()
@@ -88,7 +89,9 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam
8889
def test_get_caller_identity_arn_from_an_user(boto_session):
8990
sess = Session(boto_session)
9091
arn = "arn:aws:iam::369233609183:user/mia"
91-
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
92+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
93+
"Arn": arn
94+
}
9295
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}
9396

9497
actual = sess.get_caller_identity_arn()
@@ -98,7 +101,9 @@ def test_get_caller_identity_arn_from_an_user(boto_session):
98101
def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session):
99102
sess = Session(boto_session)
100103
arn = "arn:aws:iam::369233609183:user/mia"
101-
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
104+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
105+
"Arn": arn
106+
}
102107
sess.boto_session.client("iam").get_role.side_effect = ClientError({}, {})
103108

104109
with patch("logging.Logger.warning") as mock_logger:
@@ -112,7 +117,9 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
112117
arn = (
113118
"arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122"
114119
)
115-
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
120+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
121+
"Arn": arn
122+
}
116123

117124
expected_role = "arn:aws:iam::369233609183:role/SageMakerRole"
118125
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}}
@@ -124,7 +131,9 @@ def test_get_caller_identity_arn_from_a_role(boto_session):
124131
def test_get_caller_identity_arn_from_a_execution_role(boto_session):
125132
sess = Session(boto_session)
126133
arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker"
127-
sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn}
134+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
135+
"Arn": arn
136+
}
128137
sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}}
129138

130139
actual = sess.get_caller_identity_arn()
@@ -138,7 +147,7 @@ def test_get_caller_identity_arn_from_role_with_path(boto_session):
138147
sess = Session(boto_session)
139148
arn_prefix = "arn:aws:iam::369233609183:role"
140149
role_name = "name"
141-
sess.boto_session.client("sts").get_caller_identity.return_value = {
150+
sess.boto_session.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
142151
"Arn": "/".join([arn_prefix, role_name])
143152
}
144153

@@ -344,7 +353,9 @@ def test_s3_input_all_arguments():
344353
@pytest.fixture()
345354
def sagemaker_session():
346355
boto_mock = Mock(name="boto_session")
347-
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
356+
boto_mock.client("sts", endpoint_url=STS_ENDPOINT).get_caller_identity.return_value = {
357+
"Account": "123"
358+
}
348359
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
349360
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
350361
return ims

tests/unit/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,8 @@ def walk():
560560

561561
result = set(walk())
562562
return result if result else {}
563+
564+
565+
def test_sts_regional_endpoint():
566+
endpoint = sagemaker.utils.sts_regional_endpoint("us-west-2")
567+
assert endpoint == "sts.us-west-2.amazonaws.com"

0 commit comments

Comments
 (0)