Skip to content

Commit b566052

Browse files
laurenyunadiayaajaykarpur
committed
change: refactor logic in fw_utils and fill in docstrings (#1192)
* change: refactor logic in fw_utils and fill in docstrings * black format * fix MXNet serving EI version * use full version for _LOWEST_MMS_VERSION Co-authored-by: Nadia Yakimakha <[email protected]> Co-authored-by: Ajay Karpur <[email protected]>
1 parent a25d5a9 commit b566052

File tree

2 files changed

+62
-91
lines changed

2 files changed

+62
-91
lines changed

src/sagemaker/fw_utils.py

Lines changed: 61 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16-
from collections import namedtuple
17-
1816
import os
1917
import re
2018
import shutil
2119
import tempfile
20+
from collections import namedtuple
2221

2322
import sagemaker.utils
24-
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
2523
from sagemaker import s3
24+
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
2625

2726
_TAR_SOURCE_FILENAME = "source.tar.gz"
2827

@@ -69,127 +68,96 @@
6968
"tensorflow-serving-eia": "tensorflow-inference-eia",
7069
"mxnet": "mxnet-training",
7170
"mxnet-serving": "mxnet-inference",
71+
"mxnet-serving-eia": "mxnet-inference-eia",
7272
"pytorch": "pytorch-training",
7373
"pytorch-serving": "pytorch-inference",
74-
"mxnet-serving-eia": "mxnet-inference-eia",
7574
}
7675

7776
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
78-
"tensorflow-scriptmode": [1, 13, 1],
77+
"tensorflow-scriptmode": {"py3": [1, 13, 1], "py2": [1, 14, 0]},
7978
"tensorflow-serving": [1, 13, 0],
8079
"tensorflow-serving-eia": [1, 14, 0],
81-
"mxnet": [1, 4, 1],
82-
"mxnet-serving": [1, 4, 1],
80+
"mxnet": {"py3": [1, 4, 1], "py2": [1, 6, 0]},
81+
"mxnet-serving": {"py3": [1, 4, 1], "py2": [1, 6, 0]},
82+
"mxnet-serving-eia": [1, 4, 1],
8383
"pytorch": [1, 2, 0],
8484
"pytorch-serving": [1, 2, 0],
85-
"mxnet-serving-eia": [1, 4, 1],
8685
}
8786

8887

8988
def is_version_equal_or_higher(lowest_version, framework_version):
9089
"""Determine whether the ``framework_version`` is equal to or higher than
9190
``lowest_version``
91+
9292
Args:
9393
lowest_version (List[int]): lowest version represented in an integer
9494
list
9595
framework_version (str): framework version string
96+
9697
Returns:
97-
bool: Whether or not framework_version is equal to or higher than
98-
lowest_version
98+
bool: Whether or not ``framework_version`` is equal to or higher than
99+
``lowest_version``
99100
"""
100101
version_list = [int(s) for s in framework_version.split(".")]
101102
return version_list >= lowest_version[0 : len(version_list)]
102103

103104

104-
def _is_merged_versions(framework, framework_version):
105-
"""
105+
def _is_dlc_version(framework, framework_version, py_version):
106+
"""Return if the framework's version uses the corresponding DLC image.
107+
106108
Args:
107-
framework:
108-
framework_version:
109+
framework (str): The framework name, e.g. "tensorflow-scriptmode"
110+
framework_version (str): The framework version
111+
py_version (str): The Python version, e.g. "py3"
112+
113+
Returns:
114+
bool: Whether or not the framework's version uses the DLC image.
109115
"""
110116
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
117+
if isinstance(lowest_version_list, dict):
118+
lowest_version_list = lowest_version_list[py_version]
119+
111120
if lowest_version_list:
112121
return is_version_equal_or_higher(lowest_version_list, framework_version)
113122
return False
114123

115124

116-
def _using_merged_images(region, framework, py_version, framework_version):
117-
"""
118-
Args:
119-
region:
120-
framework:
121-
py_version:
122-
accelerator_type:
123-
framework_version:
124-
"""
125-
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
126-
not_py2 = py_version == "py3" or py_version is None
127-
is_merged_versions = _is_merged_versions(framework, framework_version)
128-
129-
return (
130-
((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION)
131-
and is_merged_versions
132-
# TODO: should be not mxnet-1.14.1-py2 instead?
133-
and (
134-
not_py2
135-
or _is_tf_14_or_later(framework, framework_version)
136-
or _is_pt_12_or_later(framework, framework_version)
137-
or _is_mxnet_16_or_later(framework, framework_version)
138-
)
139-
)
140-
125+
def _use_dlc_image(region, framework, py_version, framework_version):
126+
"""Return if the DLC image should be used for the given framework,
127+
framework version, Python version, and region.
141128
142-
def _is_tf_14_or_later(framework, framework_version):
143-
"""
144129
Args:
145-
framework:
146-
framework_version:
147-
"""
148-
# Asimov team now owns Tensorflow 1.14.0 py2 and py3
149-
asimov_lowest_tf_py2 = [1, 14, 0]
150-
version = [int(s) for s in framework_version.split(".")]
151-
return (
152-
framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)]
153-
)
154-
130+
region (str): The AWS region.
131+
framework (str): The framework name, e.g. "tensorflow-scriptmode".
132+
py_version (str): The Python version, e.g. "py3".
133+
framework_version (str): The framework version.
155134
156-
def _is_pt_12_or_later(framework, framework_version):
157-
"""
158-
Args:
159-
framework: Name of the frameowork
160-
framework_version: framework version
135+
Returns:
136+
bool: Whether or not to use the corresponding DLC image.
161137
"""
162-
# Asimov team now owns PyTorch 1.2.0 py2 and py3
163-
asimov_lowest_pt = [1, 2, 0]
164-
version = [int(s) for s in framework_version.split(".")]
165-
is_pytorch = framework in ("pytorch", "pytorch-serving")
166-
return is_pytorch and version >= asimov_lowest_pt[0 : len(version)]
138+
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
139+
is_dlc_version = _is_dlc_version(framework, framework_version, py_version)
167140

168-
169-
def _is_mxnet_16_or_later(framework, framework_version):
170-
"""
171-
Args:
172-
framework: Name of the frameowork
173-
framework_version: framework version
174-
"""
175-
# Asimov team now owns MXNet 1.6.0 py2 and py3
176-
asimov_lowest_pt = [1, 6, 0]
177-
version = [int(s) for s in framework_version.split(".")]
178-
is_mxnet = framework in ("mxnet", "mxnet-serving")
179-
return is_mxnet and version >= asimov_lowest_pt[0 : len(version)]
141+
return ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION) and is_dlc_version
180142

181143

182144
def _registry_id(region, framework, py_version, account, framework_version):
183-
"""
145+
"""Return the Amazon ECR registry number (or AWS account ID) for
146+
the given framework, framework version, Python version, and region.
147+
184148
Args:
185-
region:
186-
framework:
187-
py_version:
188-
account:
189-
accelerator_type:
190-
framework_version:
149+
region (str): The AWS region.
150+
framework (str): The framework name, e.g. "tensorflow-scriptmode".
151+
py_version (str): The Python version, e.g. "py3".
152+
account (str): The AWS account ID to use as a default.
153+
framework_version (str): The framework version.
154+
155+
Returns:
156+
str: The appropriate Amazon ECR registry number. If there is no
157+
specific one for the framework, framework version, Python version,
158+
and region, then ``account`` is returned.
191159
"""
192-
if _using_merged_images(region, framework, py_version, framework_version):
160+
if _use_dlc_image(region, framework, py_version, framework_version):
193161
if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION:
194162
return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION.get(region)
195163
if region in ASIMOV_VALID_ACCOUNTS_BY_REGION:
@@ -211,6 +179,7 @@ def create_image_uri(
211179
optimized_families=None,
212180
):
213181
"""Return the ECR URI of an image.
182+
214183
Args:
215184
region (str): AWS region where the image is uploaded.
216185
framework (str): framework used by the image.
@@ -225,6 +194,7 @@ def create_image_uri(
225194
accelerator_type (str): SageMaker Elastic Inference accelerator type.
226195
optimized_families (str): Instance families for which there exist
227196
specific optimized images.
197+
228198
Returns:
229199
str: The appropriate image URI based on the given parameters.
230200
"""
@@ -240,7 +210,7 @@ def create_image_uri(
240210
):
241211
framework += "-eia"
242212

243-
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
213+
# Handle account number for specific cases (e.g. GovCloud, opt-in regions, DLC images etc.)
244214
if account is None:
245215
account = _registry_id(
246216
region=region,
@@ -271,18 +241,19 @@ def create_image_uri(
271241
else:
272242
device_type = "cpu"
273243

274-
using_merged_images = _using_merged_images(region, framework, py_version, framework_version)
244+
use_dlc_image = _use_dlc_image(region, framework, py_version, framework_version)
275245

276-
if not py_version or (using_merged_images and framework == "tensorflow-serving-eia"):
246+
if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia"):
277247
tag = "{}-{}".format(framework_version, device_type)
278248
else:
279249
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
280250

281-
if using_merged_images:
282-
return "{}/{}:{}".format(
283-
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
284-
)
285-
return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)
251+
if use_dlc_image:
252+
ecr_repo = MERGED_FRAMEWORKS_REPO_MAP[framework]
253+
else:
254+
ecr_repo = "sagemaker-{}".format(framework)
255+
256+
return "{}/{}:{}".format(get_ecr_image_uri_prefix(account, region), ecr_repo, tag)
286257

287258

288259
def _accelerator_type_valid_for_framework(

src/sagemaker/mxnet/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class MXNetModel(FrameworkModel):
5353
"""An MXNet SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
5454

5555
__framework_name__ = "mxnet"
56-
_LOWEST_MMS_VERSION = "1.4"
56+
_LOWEST_MMS_VERSION = "1.4.0"
5757

5858
def __init__(
5959
self,

0 commit comments

Comments
 (0)