Skip to content

Commit a3dd208

Browse files
committed
feature: use deep learnining images
1 parent 7c1bdf3 commit a3dd208

File tree

3 files changed

+114
-16
lines changed

3 files changed

+114
-16
lines changed

src/sagemaker/fw_utils.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,44 @@
5454
VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
5555
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
5656

57+
MERGED_FRAMEWORKS_REPO_MAP = {
58+
"tensorflow-scriptmode": "tensorflow-training",
59+
"mxnet": "mxnet-training",
60+
"tensorflow-serving": "tensorflow-inference",
61+
"mxnet-serving": "mxnet-inference",
62+
}
63+
64+
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
65+
"tensorflow-scriptmode": [1, 13, 1],
66+
"mxnet": [1, 4, 1],
67+
"tensorflow-serving": [1, 13, 0],
68+
"mxnet-serving": [1, 4, 1],
69+
}
70+
71+
72+
def _is_merged_versions(framework, framework_version):
73+
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
74+
if lowest_version_list:
75+
version_list = [int(s) for s in framework_version.split(".")]
76+
return version_list >= lowest_version_list[0 : len(version_list)]
77+
else:
78+
return False
79+
80+
81+
def _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
82+
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
83+
is_py3 = py_version == "py3" or py_version is None
84+
is_merged_versions = _is_merged_versions(framework, framework_version)
85+
return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None
86+
87+
88+
def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
89+
90+
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
91+
return "763104351884"
92+
else:
93+
return VALID_ACCOUNTS_BY_REGION.get(region, account)
94+
5795

5896
def create_image_uri(
5997
region,
@@ -86,8 +124,15 @@ def create_image_uri(
86124
if py_version and py_version not in VALID_PY_VERSIONS:
87125
raise ValueError("invalid py_version argument: {}".format(py_version))
88126

89-
# Handle Account Number for Gov Cloud
90-
account = VALID_ACCOUNTS_BY_REGION.get(region, account)
127+
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
128+
account = _registry_id(
129+
region=region,
130+
framework=framework,
131+
py_version=py_version,
132+
account=account,
133+
accelerator_type=accelerator_type,
134+
framework_version=framework_version,
135+
)
91136

92137
# Handle Local Mode
93138
if instance_type.startswith("local"):
@@ -121,7 +166,14 @@ def create_image_uri(
121166
):
122167
framework += "-eia"
123168

124-
return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)
169+
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
170+
return "{}/{}:{}".format(
171+
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
172+
)
173+
else:
174+
return "{}/sagemaker-{}:{}".format(
175+
get_ecr_image_uri_prefix(account, region), framework, tag
176+
)
125177

126178

127179
def _accelerator_type_valid_for_framework(
@@ -264,7 +316,7 @@ def framework_name_from_image(image_name):
264316
# extract framework, python version and image tag
265317
# We must support both the legacy and current image name format.
266318
name_pattern = re.compile(
267-
r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
319+
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
268320
)
269321
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
270322

tests/unit/test_fw_utils.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,59 @@ def test_create_image_uri_gov_cloud():
136136
)
137137

138138

139+
def test_create_image_uri_merged():
140+
image_uri = fw_utils.create_image_uri(
141+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3"
142+
)
143+
assert (
144+
image_uri
145+
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13.1-gpu-py3"
146+
)
147+
148+
image_uri = fw_utils.create_image_uri(
149+
"us-west-2", "tensorflow-serving", "ml.c4.2xlarge", "1.13.1"
150+
)
151+
assert (
152+
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.13.1-cpu"
153+
)
154+
155+
image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3")
156+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.4.1-gpu-py3"
157+
158+
image_uri = fw_utils.create_image_uri(
159+
"us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3"
160+
)
161+
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3"
162+
163+
164+
def test_create_image_uri_merged_py2():
165+
image_uri = fw_utils.create_image_uri(
166+
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py2"
167+
)
168+
assert (
169+
image_uri
170+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2"
171+
)
172+
173+
image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2")
174+
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2"
175+
176+
image_uri = fw_utils.create_image_uri(
177+
"us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py2"
178+
)
179+
assert (
180+
image_uri
181+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py2"
182+
)
183+
184+
139185
def test_create_image_uri_accelerator_tf():
140186
image_uri = fw_utils.create_image_uri(
141-
MOCK_REGION,
142-
"tensorflow",
143-
"ml.p3.2xlarge",
144-
"1.0rc",
145-
"py3",
146-
accelerator_type="ml.eia1.medium",
187+
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
147188
)
148189
assert (
149190
image_uri
150-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3"
191+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0-gpu-py3"
151192
)
152193

153194

@@ -156,13 +197,13 @@ def test_create_image_uri_accelerator_mxnet_serving():
156197
MOCK_REGION,
157198
"mxnet-serving",
158199
"ml.p3.2xlarge",
159-
"1.0rc",
200+
"1.0",
160201
"py3",
161202
accelerator_type="ml.eia1.medium",
162203
)
163204
assert (
164205
image_uri
165-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3"
206+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0-gpu-py3"
166207
)
167208

168209

@@ -171,13 +212,13 @@ def test_create_image_uri_local_sagemaker_notebook_accelerator():
171212
MOCK_REGION,
172213
"mxnet",
173214
"ml.p3.2xlarge",
174-
"1.0rc",
215+
"1.0",
175216
"py3",
176217
accelerator_type="local_sagemaker_notebook",
177218
)
178219
assert (
179220
image_uri
180-
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3"
221+
== "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0-gpu-py3"
181222
)
182223

183224

@@ -555,6 +596,11 @@ def test_framework_name_from_image_tf_scriptmode():
555596
"scriptmode",
556597
) == fw_utils.framework_name_from_image(image_name)
557598

599+
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13-cpu-py3"
600+
assert ("tensorflow", "py3", "1.13-cpu-py3", "training") == fw_utils.framework_name_from_image(
601+
image_name
602+
)
603+
558604

559605
def test_framework_name_from_image_rl():
560606
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3"

tests/unit/test_tf_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def test_script_mode_tensorboard(
924924
sagemaker_session=sagemaker_session,
925925
train_instance_count=INSTANCE_COUNT,
926926
train_instance_type=INSTANCE_TYPE,
927-
framework_version="some_version",
927+
framework_version="1.0",
928928
script_mode=True,
929929
)
930930
popen().poll.return_value = None

0 commit comments

Comments
 (0)