Skip to content

Commit f501fac

Browse files
committed
change: Allowing account overrides in special regions
1 parent 52c7475 commit f501fac

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

src/sagemaker/fw_utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
5959
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634"}
6060
ASIMOV_OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "871362719292"}
61+
DEFAULT_ACCOUNT = "520713654638"
6162

6263
MERGED_FRAMEWORKS_REPO_MAP = {
6364
"tensorflow-scriptmode": "tensorflow-training",
@@ -183,7 +184,7 @@ def create_image_uri(
183184
instance_type,
184185
framework_version,
185186
py_version=None,
186-
account="520713654638",
187+
account=None,
187188
accelerator_type=None,
188189
optimized_families=None,
189190
):
@@ -218,13 +219,14 @@ def create_image_uri(
218219
framework += "-eia"
219220

220221
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
221-
account = _registry_id(
222-
region=region,
223-
framework=framework,
224-
py_version=py_version,
225-
account=account,
226-
framework_version=framework_version,
227-
)
222+
if account is None:
223+
account = _registry_id(
224+
region=region,
225+
framework=framework,
226+
py_version=py_version,
227+
account=DEFAULT_ACCOUNT,
228+
framework_version=framework_version,
229+
)
228230

229231
# Handle Local Mode
230232
if instance_type.startswith("local"):

tests/unit/test_fw_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ def test_create_image_uri_cpu():
6868
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
6969

7070
image_uri = fw_utils.create_image_uri(
71-
"us-gov-west-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23"
71+
"us-gov-west-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
7272
)
7373
assert (
7474
image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
7575
)
7676

7777
image_uri = fw_utils.create_image_uri(
78-
"us-iso-east-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23"
78+
"us-iso-east-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2"
7979
)
8080
assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2"
8181

@@ -188,6 +188,25 @@ def test_mxnet_eia_images():
188188
== "763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-inference-eia:1.4.1-cpu-py3"
189189
)
190190

191+
def test_create_image_uri_override_account():
192+
image_uri = fw_utils.create_image_uri(
193+
"us-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"
194+
)
195+
assert image_uri == "fake.dkr.ecr.us-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
196+
197+
198+
def test_create_image_uri_gov_cloud_override_account():
199+
image_uri = fw_utils.create_image_uri(
200+
"us-gov-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"
201+
)
202+
assert image_uri == "fake.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
203+
204+
205+
def test_create_image_uri_hkg_override_account():
206+
image_uri = fw_utils.create_image_uri(
207+
MOCK_HKG_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", account="fake"
208+
)
209+
assert {image_uri == "fake.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"}
191210

192211
def test_create_image_uri_merged():
193212
image_uri = fw_utils.create_image_uri(

0 commit comments

Comments
 (0)