Skip to content

change: dynamically determine AWS domain based on region #1299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from datetime import datetime
from functools import wraps

import botocore
import six
from six.moves.urllib import parse
import botocore


ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$"
Expand Down Expand Up @@ -611,8 +611,8 @@ def get_ecr_image_uri_prefix(account, region):
Returns:
(str): URI prefix of ECR image
"""
domain = _domain_for_region(region)
return "{}.dkr.ecr.{}.{}".format(account, region, domain)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
return "{}.dkr.{}".format(account, endpoint_data["hostname"])


def sts_regional_endpoint(region):
Expand All @@ -630,8 +630,8 @@ def sts_regional_endpoint(region):
Returns:
str: AWS STS regional endpoint
"""
domain = _domain_for_region(region)
return "https://sts.{}.{}".format(region, domain)
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
return "https://{}".format(endpoint_data["hostname"])


def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS):
Expand All @@ -654,7 +654,7 @@ def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=DEFAULT_
)


def _domain_for_region(region):
def _botocore_resolver():
"""Get the DNS suffix for the given region.

Args:
Expand All @@ -663,7 +663,8 @@ def _domain_for_region(region):
Returns:
str: the DNS suffix
"""
return "c2s.ic.gov" if region == "us-iso-east-1" else "amazonaws.com"
loader = botocore.loaders.create_loader()
return botocore.regions.EndpointResolver(loader.load_data("endpoints"))


class DeferredError(object):
Expand Down
46 changes: 39 additions & 7 deletions tests/unit/test_airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def test_byo_training_config_all_args(sagemaker_session):
]
),
)
def test_framework_training_config_required_args(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
)
def test_framework_training_config_required_args(ecr_prefix, sagemaker_session):
tf = tensorflow.TensorFlow(
entry_point="/some/script.py",
framework_version="1.10.0",
Expand Down Expand Up @@ -248,7 +252,11 @@ def test_framework_training_config_required_args(sagemaker_session):
"sagemaker.estimator.parse_s3_url",
MagicMock(return_value=["{{ output_path }}", "{{ output_path }}"]),
)
def test_framework_training_config_all_args(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
)
def test_framework_training_config_all_args(ecr_prefix, sagemaker_session):
tf = tensorflow.TensorFlow(
entry_point="{{ entry_point }}",
source_dir="{{ source_dir }}",
Expand Down Expand Up @@ -478,7 +486,11 @@ def test_amazon_alg_training_config_all_args(sagemaker_session):
]
),
)
def test_framework_tuning_config(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
)
def test_framework_tuning_config(ecr_prefix, sagemaker_session):
mxnet_estimator = mxnet.MXNet(
entry_point="{{ entry_point }}",
source_dir="{{ source_dir }}",
Expand Down Expand Up @@ -617,7 +629,15 @@ def test_framework_tuning_config(sagemaker_session):
]
),
)
def test_multi_estimator_tuning_config(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="520713654638.dkr.ecr.us-west-2.amazonaws.com",
)
@patch(
"sagemaker.amazon.amazon_estimator.get_ecr_image_uri_prefix",
return_value="174872318107.dkr.ecr.us-west-2.amazonaws.com",
)
def test_multi_estimator_tuning_config(algo_ecr_prefix, fw_ecr_prefix, sagemaker_session):
estimator_dict = {}
hyperparameter_ranges_dict = {}
objective_metric_name_dict = {}
Expand Down Expand Up @@ -1025,7 +1045,11 @@ def test_amazon_alg_model_config(sagemaker_session):
]
),
)
def test_model_config_from_framework_estimator(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
)
def test_model_config_from_framework_estimator(ecr_prefix, sagemaker_session):
mxnet_estimator = mxnet.MXNet(
entry_point="{{ entry_point }}",
source_dir="{{ source_dir }}",
Expand Down Expand Up @@ -1179,7 +1203,11 @@ def test_transform_config(sagemaker_session):
]
),
)
def test_transform_config_from_framework_estimator(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
)
def test_transform_config_from_framework_estimator(ecr_prefix, sagemaker_session):
mxnet_estimator = mxnet.MXNet(
entry_point="{{ entry_point }}",
source_dir="{{ source_dir }}",
Expand Down Expand Up @@ -1420,7 +1448,11 @@ def test_deploy_amazon_alg_model_config(sagemaker_session):
]
),
)
def test_deploy_config_from_framework_estimator(sagemaker_session):
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value="763104351884.dkr.ecr.us-west-2.amazonaws.com",
)
def test_deploy_config_from_framework_estimator(ecr_prefix, sagemaker_session):
mxnet_estimator = mxnet.MXNet(
entry_point="{{ entry_point }}",
source_dir="{{ source_dir }}",
Expand Down
59 changes: 48 additions & 11 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
REGION = "us-west-2"
SCRIPT_PATH = "script.py"
TIMESTAMP = "2017-10-10-14-14-15"
ECR_PREFIX_FORMAT = "{}.dkr.ecr.mars-south-3.amazonaws.com"

MOCK_ACCOUNT = "520713654638"
MOCK_FRAMEWORK = "mlfw"
MOCK_REGION = "mars-south-3"
MOCK_ACCELERATOR = "eia1.medium"
Expand Down Expand Up @@ -165,7 +167,9 @@ def sagemaker_session():
return session_mock


def test_create_image_uri_cpu():
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix")
def test_create_image_uri_cpu(ecr_prefix):
ecr_prefix.return_value = ECR_PREFIX_FORMAT.format("23")
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23"
)
Expand All @@ -176,20 +180,23 @@ def test_create_image_uri_cpu():
)
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"

ecr_prefix.return_value = "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com"
image_uri = fw_utils.create_image_uri(
"us-gov-west-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
)
assert (
image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
)

ecr_prefix.return_value = "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov"
image_uri = fw_utils.create_image_uri(
"us-iso-east-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
)
assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2"


def test_create_image_uri_no_python():
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
def test_create_image_uri_no_python(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", account="23"
)
Expand All @@ -201,7 +208,8 @@ def test_create_image_uri_bad_python():
fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py0")


def test_create_image_uri_gpu():
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
def test_create_image_uri_gpu(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", "23"
)
Expand All @@ -213,7 +221,8 @@ def test_create_image_uri_gpu():
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"


def test_create_image_uri_accelerator_tfs():
@patch("sagemaker.fw_utils.get_ecr_image_uri_prefix", return_value=ECR_PREFIX_FORMAT.format("23"))
def test_create_image_uri_accelerator_tfs(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION,
"tensorflow-serving",
Expand All @@ -228,7 +237,11 @@ def test_create_image_uri_accelerator_tfs():
)


def test_create_image_uri_default_account():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_create_image_uri_default_account(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3"
)
Expand Down Expand Up @@ -511,7 +524,11 @@ def test_create_image_uri_tensorflow(tf_version):
)


def test_create_image_uri_accelerator_tf():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_create_image_uri_accelerator_tf(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
)
Expand All @@ -521,7 +538,11 @@ def test_create_image_uri_accelerator_tf():
)


def test_create_image_uri_accelerator_mxnet_serving():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_create_image_uri_accelerator_mxnet_serving(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION,
"mxnet-serving",
Expand All @@ -536,7 +557,11 @@ def test_create_image_uri_accelerator_mxnet_serving():
)


def test_create_image_uri_local_sagemaker_notebook_accelerator():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_create_image_uri_local_sagemaker_notebook_accelerator(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION,
"mxnet",
Expand Down Expand Up @@ -608,7 +633,11 @@ def test_invalid_instance_type():
fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "p3.2xlarge", "1.0.0", "py3")


def test_optimized_family():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_optimized_family(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION,
MOCK_FRAMEWORK,
Expand All @@ -622,7 +651,11 @@ def test_optimized_family():
)


def test_unoptimized_cpu_family():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_unoptimized_cpu_family(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.m4.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"]
)
Expand All @@ -631,7 +664,11 @@ def test_unoptimized_cpu_family():
)


def test_unoptimized_gpu_family():
@patch(
"sagemaker.fw_utils.get_ecr_image_uri_prefix",
return_value=ECR_PREFIX_FORMAT.format(MOCK_ACCOUNT),
)
def test_unoptimized_gpu_family(ecr_prefix):
image_uri = fw_utils.create_image_uri(
MOCK_REGION, MOCK_FRAMEWORK, "ml.p2.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"]
)
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BUCKET_NAME = "mybucket"
REGION = "us-west-2"
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
ECR_PREFIX = "246618743249.dkr.ecr.us-west-2.amazonaws.com"
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"

PROCESSING_JOB_DESCRIPTION = {
Expand Down Expand Up @@ -94,9 +95,12 @@ def sagemaker_session():
return session_mock


@patch("sagemaker.fw_registry.get_ecr_image_uri_prefix", return_value=ECR_PREFIX)
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_sklearn_processor_with_required_parameters(exists_mock, isfile_mock, sagemaker_session):
def test_sklearn_processor_with_required_parameters(
exists_mock, isfile_mock, ecr_prefix, sagemaker_session
):
processor = SKLearnProcessor(
role=ROLE,
instance_type="ml.m4.xlarge",
Expand All @@ -117,9 +121,10 @@ def test_sklearn_processor_with_required_parameters(exists_mock, isfile_mock, sa
sagemaker_session.process.assert_called_with(**expected_args)


@patch("sagemaker.fw_registry.get_ecr_image_uri_prefix", return_value=ECR_PREFIX)
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_sklearn_with_all_parameters(exists_mock, isfile_mock, sagemaker_session):
def test_sklearn_with_all_parameters(exists_mock, isfile_mock, ecr_prefix, sagemaker_session):
processor = SKLearnProcessor(
role=ROLE,
framework_version="0.20.0",
Expand Down
Loading