Skip to content

Commit e5aa5a7

Browse files
authored
change: dynamically determine AWS domain based on region (#1299)
1 parent cce8eea commit e5aa5a7

File tree

6 files changed

+126
-33
lines changed

6 files changed

+126
-33
lines changed

src/sagemaker/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from datetime import datetime
2727
from functools import wraps
2828

29+
import botocore
2930
import six
3031
from six.moves.urllib import parse
31-
import botocore
3232

3333

3434
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$"
@@ -611,8 +611,8 @@ def get_ecr_image_uri_prefix(account, region):
611611
Returns:
612612
(str): URI prefix of ECR image
613613
"""
614-
domain = _domain_for_region(region)
615-
return "{}.dkr.ecr.{}.{}".format(account, region, domain)
614+
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
615+
return "{}.dkr.{}".format(account, endpoint_data["hostname"])
616616

617617

618618
def sts_regional_endpoint(region):
@@ -630,8 +630,8 @@ def sts_regional_endpoint(region):
630630
Returns:
631631
str: AWS STS regional endpoint
632632
"""
633-
domain = _domain_for_region(region)
634-
return "https://sts.{}.{}".format(region, domain)
633+
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
634+
return "https://{}".format(endpoint_data["hostname"])
635635

636636

637637
def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS):
@@ -654,7 +654,7 @@ def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_
654654
)
655655

656656

657-
def _domain_for_region(region):
657+
def _botocore_resolver():
658658
"""Get the DNS suffix for the given region.
659659
660660
Args:
@@ -663,7 +663,8 @@ def _domain_for_region(region):
663663
Returns:
664664
str: the DNS suffix
665665
"""
666-
return "c2s.ic.gov" if region == "us-iso-east-1" else "amazonaws.com"
666+
loader = botocore.loaders.create_loader()
667+
return botocore.regions.EndpointResolver(loader.load_data("endpoints"))
667668

668669

669670
class DeferredError(object):

tests/unit/test_airflow.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ def test_byo_training_config_all_args(sagemaker_session):
173173
]
174174
),
175175
)
176-
def test_framework_training_config_required_args(sagemaker_session):
176+
@patch(
177+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
178+
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
179+
)
180+
def test_framework_training_config_required_args(ecr_prefix, sagemaker_session):
177181
tf = tensorflow.TensorFlow(
178182
entry_point="/some/script.py",
179183
framework_version="1.10.0",
@@ -248,7 +252,11 @@ def test_framework_training_config_required_args(sagemaker_session):
248252
"sagemaker.estimator.parse_s3_url",
249253
MagicMock(return_value=["{{ output_path }}", "{{ output_path }}"]),
250254
)
251-
def test_framework_training_config_all_args(sagemaker_session):
255+
@patch(
256+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
257+
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
258+
)
259+
def test_framework_training_config_all_args(ecr_prefix, sagemaker_session):
252260
tf = tensorflow.TensorFlow(
253261
entry_point="{{ entry_point }}",
254262
source_dir="{{ source_dir }}",
@@ -478,7 +486,11 @@ def test_amazon_alg_training_config_all_args(sagemaker_session):
478486
]
479487
),
480488
)
481-
def test_framework_tuning_config(sagemaker_session):
489+
@patch(
490+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
491+
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
492+
)
493+
def test_framework_tuning_config(ecr_prefix, sagemaker_session):
482494
mxnet_estimator = mxnet.MXNet(
483495
entry_point="{{ entry_point }}",
484496
source_dir="{{ source_dir }}",
@@ -617,7 +629,15 @@ def test_framework_tuning_config(sagemaker_session):
617629
]
618630
),
619631
)
620-
def test_multi_estimator_tuning_config(sagemaker_session):
632+
@patch(
633+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
634+
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
635+
)
636+
@patch(
637+
"sagemaker.amazon.amazon_estimator.get_ecr_image_uri_prefix",
638+
return_value="174872318107.dkr.ecr.us-west-2.amazonaws.com",
639+
)
640+
def test_multi_estimator_tuning_config(algo_ecr_prefix, fw_ecr_prefix, sagemaker_session):
621641
estimator_dict = {}
622642
hyperparameter_ranges_dict = {}
623643
objective_metric_name_dict = {}
@@ -1025,7 +1045,11 @@ def test_amazon_alg_model_config(sagemaker_session):
10251045
]
10261046
),
10271047
)
1028-
def test_model_config_from_framework_estimator(sagemaker_session):
1048+
@patch(
1049+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
1050+
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
1051+
)
1052+
def test_model_config_from_framework_estimator(ecr_prefix, sagemaker_session):
10291053
mxnet_estimator = mxnet.MXNet(
10301054
entry_point="{{ entry_point }}",
10311055
source_dir="{{ source_dir }}",
@@ -1179,7 +1203,11 @@ def test_transform_config(sagemaker_session):
11791203
]
11801204
),
11811205
)
1182-
def test_transform_config_from_framework_estimator(sagemaker_session):
1206+
@patch(
1207+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
1208+
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
1209+
)
1210+
def test_transform_config_from_framework_estimator(ecr_prefix, sagemaker_session):
11831211
mxnet_estimator = mxnet.MXNet(
11841212
entry_point="{{ entry_point }}",
11851213
source_dir="{{ source_dir }}",
@@ -1420,7 +1448,11 @@ def test_deploy_amazon_alg_model_config(sagemaker_session):
14201448
]
14211449
),
14221450
)
1423-
def test_deploy_config_from_framework_estimator(sagemaker_session):
1451+
@patch(
1452+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
1453+
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
1454+
)
1455+
def test_deploy_config_from_framework_estimator(ecr_prefix, sagemaker_session):
14241456
mxnet_estimator = mxnet.MXNet(
14251457
entry_point="{{ entry_point }}",
14261458
source_dir="{{ source_dir }}",

tests/unit/test_fw_utils.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
REGION = "us-west-2"
3030
SCRIPT_PATH = "script.py"
3131
TIMESTAMP = "2017-10-10-14-14-15"
32+
ECR_PREFIX_FORMAT = "{}.dkr.ecr.mars-south-3.amazonaws.com"
3233

34+
MOCK_ACCOUNT = "520713654638"
3335
MOCK_FRAMEWORK = "mlfw"
3436
MOCK_REGION = "mars-south-3"
3537
MOCK_ACCELERATOR = "eia1.medium"
@@ -165,7 +167,9 @@ def sagemaker_session():
165167
return session_mock
166168

167169

168-
def test_create_image_uri_cpu():
170+
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix")
171+
def test_create_image_uri_cpu(ecr_prefix):
172+
ecr_prefix.return_value = ECR_PREFIX_FORMAT.format("23")
169173
image_uri = fw_utils.create_image_uri(
170174
MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23"
171175
)
@@ -176,20 +180,23 @@ def test_create_image_uri_cpu():
176180
)
177181
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
178182

183+
ecr_prefix.return_value = "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com"
179184
image_uri = fw_utils.create_image_uri(
180185
"us-gov-west-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
181186
)
182187
assert (
183188
image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
184189
)
185190

191+
ecr_prefix.return_value = "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov"
186192
image_uri = fw_utils.create_image_uri(
187193
"us-iso-east-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
188194
)
189195
assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2"
190196

191197

192-
def test_create_image_uri_no_python():
198+
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
199+
def test_create_image_uri_no_python(ecr_prefix):
193200
image_uri = fw_utils.create_image_uri(
194201
MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", account="23"
195202
)
@@ -201,7 +208,8 @@ def test_create_image_uri_bad_python():
201208
fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py0")
202209

203210

204-
def test_create_image_uri_gpu():
211+
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
212+
def test_create_image_uri_gpu(ecr_prefix):
205213
image_uri = fw_utils.create_image_uri(
206214
MOCK_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", "23"
207215
)
@@ -213,7 +221,8 @@ def test_create_image_uri_gpu():
213221
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
214222

215223

216-
def test_create_image_uri_accelerator_tfs():
224+
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
225+
def test_create_image_uri_accelerator_tfs(ecr_prefix):
217226
image_uri = fw_utils.create_image_uri(
218227
MOCK_REGION,
219228
"tensorflow-serving",
@@ -228,7 +237,11 @@ def test_create_image_uri_accelerator_tfs():
228237
)
229238

230239

231-
def test_create_image_uri_default_account():
240+
@patch(
241+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
242+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
243+
)
244+
def test_create_image_uri_default_account(ecr_prefix):
232245
image_uri = fw_utils.create_image_uri(
233246
MOCK_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3"
234247
)
@@ -511,7 +524,11 @@ def test_create_image_uri_tensorflow(tf_version):
511524
)
512525

513526

514-
def test_create_image_uri_accelerator_tf():
527+
@patch(
528+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
529+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
530+
)
531+
def test_create_image_uri_accelerator_tf(ecr_prefix):
515532
image_uri = fw_utils.create_image_uri(
516533
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
517534
)
@@ -521,7 +538,11 @@ def test_create_image_uri_accelerator_tf():
521538
)
522539

523540

524-
def test_create_image_uri_accelerator_mxnet_serving():
541+
@patch(
542+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
543+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
544+
)
545+
def test_create_image_uri_accelerator_mxnet_serving(ecr_prefix):
525546
image_uri = fw_utils.create_image_uri(
526547
MOCK_REGION,
527548
"mxnet-serving",
@@ -536,7 +557,11 @@ def test_create_image_uri_accelerator_mxnet_serving():
536557
)
537558

538559

539-
def test_create_image_uri_local_sagemaker_notebook_accelerator():
560+
@patch(
561+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
562+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
563+
)
564+
def test_create_image_uri_local_sagemaker_notebook_accelerator(ecr_prefix):
540565
image_uri = fw_utils.create_image_uri(
541566
MOCK_REGION,
542567
"mxnet",
@@ -608,7 +633,11 @@ def test_invalid_instance_type():
608633
fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "p3.2xlarge", "1.0.0", "py3")
609634

610635

611-
def test_optimized_family():
636+
@patch(
637+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
638+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
639+
)
640+
def test_optimized_family(ecr_prefix):
612641
image_uri = fw_utils.create_image_uri(
613642
MOCK_REGION,
614643
MOCK_FRAMEWORK,
@@ -622,7 +651,11 @@ def test_optimized_family():
622651
)
623652

624653

625-
def test_unoptimized_cpu_family():
654+
@patch(
655+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
656+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
657+
)
658+
def test_unoptimized_cpu_family(ecr_prefix):
626659
image_uri = fw_utils.create_image_uri(
627660
MOCK_REGION, MOCK_FRAMEWORK, "ml.m4.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"]
628661
)
@@ -631,7 +664,11 @@ def test_unoptimized_cpu_family():
631664
)
632665

633666

634-
def test_unoptimized_gpu_family():
667+
@patch(
668+
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
669+
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
670+
)
671+
def test_unoptimized_gpu_family(ecr_prefix):
635672
image_uri = fw_utils.create_image_uri(
636673
MOCK_REGION, MOCK_FRAMEWORK, "ml.p2.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"]
637674
)

tests/unit/test_processing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
BUCKET_NAME = "mybucket"
2929
REGION = "us-west-2"
3030
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
31+
ECR_PREFIX = "246618743249.dkr.ecr.us-west-2.amazonaws.com"
3132
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"
3233

3334
PROCESSING_JOB_DESCRIPTION = {
@@ -94,9 +95,12 @@ def sagemaker_session():
9495
return session_mock
9596

9697

98+
@patch("sagemaker.fw_registry.get_ecr_image_uri_prefix", return_value=ECR_PREFIX)
9799
@patch("os.path.exists", return_value=True)
98100
@patch("os.path.isfile", return_value=True)
99-
def test_sklearn_processor_with_required_parameters(exists_mock, isfile_mock, sagemaker_session):
101+
def test_sklearn_processor_with_required_parameters(
102+
exists_mock, isfile_mock, ecr_prefix, sagemaker_session
103+
):
100104
processor = SKLearnProcessor(
101105
role=ROLE,
102106
instance_type="ml.m4.xlarge",
@@ -117,9 +121,10 @@ def test_sklearn_processor_with_required_parameters(exists_mock, isfile_mock, sa
117121
sagemaker_session.process.assert_called_with(**expected_args)
118122

119123

124+
@patch("sagemaker.fw_registry.get_ecr_image_uri_prefix", return_value=ECR_PREFIX)
120125
@patch("os.path.exists", return_value=True)
121126
@patch("os.path.isfile", return_value=True)
122-
def test_sklearn_with_all_parameters(exists_mock, isfile_mock, sagemaker_session):
127+
def test_sklearn_with_all_parameters(exists_mock, isfile_mock, ecr_prefix, sagemaker_session):
123128
processor = SKLearnProcessor(
124129
role=ROLE,
125130
framework_version="0.20.0",

0 commit comments

Comments
 (0)