Skip to content

change: add account number and unit tests for govcloud #713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def registry(region_name, algorithm=None):
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
"us-iso-east-1": "490574956308",
}[region_name]
elif algorithm in ["lda"]:
account_id = {
Expand All @@ -317,6 +318,7 @@ def registry(region_name, algorithm=None):
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
"us-iso-east-1": "490574956308",
}[region_name]
elif algorithm in ["forecasting-deepar"]:
account_id = {
Expand All @@ -334,6 +336,7 @@ def registry(region_name, algorithm=None):
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
"us-iso-east-1": "490574956308",
}[region_name]
elif algorithm in ["xgboost", "seq2seq", "image-classification", "blazingtext",
"object-detection", "semantic-segmentation"]:
Expand All @@ -352,17 +355,20 @@ def registry(region_name, algorithm=None):
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
"us-iso-east-1": "490574956308",
}[region_name]
elif algorithm in ['image-classification-neo', 'xgboost-neo']:
account_id = {
'us-west-2': '301217895009',
'us-east-1': '785573368785',
'eu-west-1': '802834080501',
'us-east-2': '007439368137'
'us-east-2': '007439368137',
}[region_name]
else:
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)

domain_name = "c2s.ic.gov" if region_name == "us-iso-east-1" else "amazonaws.com"
return "{}.dkr.ecr.{}.{}".format(account_id, region_name, domain_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use single quotes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually apply the principle to use same quotes as nearby codes. Codes in this file use double quotes. Do you want to me to change all of them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed all of them.



def get_image_uri(region_name, repo_name, repo_version=1):
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/fw_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@
"us-gov-west-1": {
"sparkml-serving": "414596584902",
"scikit-learn": "414596584902"
},
"us-iso-east-1": {
"sparkml-serving": "833128469047",
"scikit-learn": "833128469047"
}
}

Expand All @@ -80,7 +84,8 @@ def registry(region_name, framework=None):
"""
try:
account_id = image_registry_map[region_name][framework]
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
domain_name = "c2s.ic.gov" if region_name == "us-iso-east-1" else "amazonaws.com"
return "{}.dkr.ecr.{}.{}".format(account_id, region_name, domain_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we maybe refactor this logic somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Wrapped in one function in utils.py

except KeyError:
logging.error("The specific image or region does not exist")
raise
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@

VALID_PY_VERSIONS = ['py2', 'py3']
VALID_EIA_FRAMEWORKS = ['tensorflow', 'mxnet']
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436'}
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436',
'us-iso-east-1': '744548109606'}


def create_image_uri(region, framework, instance_type, framework_version, py_version=None,
Expand Down Expand Up @@ -96,8 +97,10 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
optimized_families=optimized_families):
framework += '-eia'

return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
.format(account, region, framework, tag)
domain_name = "c2s.ic.gov" if region == "us-iso-east-1" else "amazonaws.com"

return "{}.dkr.ecr.{}.{}/sagemaker-{}:{}" \
.format(account, region, domain_name, framework, tag)


def _accelerator_type_valid_for_framework(framework, accelerator_type=None, optimized_families=None):
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/test_amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# Use PCA as a test implementation of AmazonAlgorithmEstimator
from sagemaker.amazon.pca import PCA
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry, get_image_uri

COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'}

Expand Down Expand Up @@ -61,6 +61,14 @@ def sagemaker_session():
return sms


def test_gov_ecr_uri():
assert get_image_uri('us-gov-west-1', 'kmeans', 'latest') == \
'226302683700.dkr.ecr.us-gov-west-1.amazonaws.com/kmeans:latest'

assert get_image_uri('us-iso-east-1', 'kmeans', 'latest') == \
'490574956308.dkr.ecr.us-iso-east-1.c2s.ic.gov/kmeans:latest'


def test_init(sagemaker_session):
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
assert pca.num_components == 55
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_fw_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_registry_sparkml_serving():
assert registry('eu-central-1', 'sparkml-serving') == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
assert registry('ca-central-1', 'sparkml-serving') == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
assert registry('us-gov-west-1', 'sparkml-serving') == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
assert registry('us-iso-east-1', 'sparkml-serving') == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"


def test_registry_sklearn():
Expand All @@ -57,6 +58,7 @@ def test_registry_sklearn():
assert registry('eu-central-1', scikit_learn_framework_name) == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
assert registry('ca-central-1', scikit_learn_framework_name) == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
assert registry('us-gov-west-1', scikit_learn_framework_name) == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
assert registry('us-iso-east-1', scikit_learn_framework_name) == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"


def test_default_sklearn_image_uri():
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def test_create_image_uri_cpu():
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'local', '1.0rc', 'py2', '23')
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'

image_uri = fw_utils.create_image_uri('us-gov-west-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'

image_uri = fw_utils.create_image_uri('us-iso-east-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
assert image_uri == '744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2'


def test_create_image_uri_no_python():
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', account='23')
Expand Down