Skip to content

change: refactor logic in fw_utils and fill in docstrings #1192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 61 additions & 90 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
"""Utility methods used by framework classes"""
from __future__ import absolute_import

from collections import namedtuple

import os
import re
import shutil
import tempfile
from collections import namedtuple

import sagemaker.utils
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
from sagemaker import s3
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN

_TAR_SOURCE_FILENAME = "source.tar.gz"

Expand Down Expand Up @@ -69,127 +68,96 @@
"tensorflow-serving-eia": "tensorflow-inference-eia",
"mxnet": "mxnet-training",
"mxnet-serving": "mxnet-inference",
"mxnet-serving-eia": "mxnet-inference-eia",
"pytorch": "pytorch-training",
"pytorch-serving": "pytorch-inference",
"mxnet-serving-eia": "mxnet-inference-eia",
}

MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
"tensorflow-scriptmode": [1, 13, 1],
"tensorflow-scriptmode": {"py3": [1, 13, 1], "py2": [1, 14, 0]},
"tensorflow-serving": [1, 13, 0],
"tensorflow-serving-eia": [1, 14, 0],
"mxnet": [1, 4, 1],
"mxnet-serving": [1, 4, 1],
"mxnet": {"py3": [1, 4, 1], "py2": [1, 6, 0]},
"mxnet-serving": {"py3": [1, 4, 1], "py2": [1, 6, 0]},
"mxnet-serving-eia": [1, 4, 1],
"pytorch": [1, 2, 0],
"pytorch-serving": [1, 2, 0],
"mxnet-serving-eia": [1, 4, 1],
}


def is_version_equal_or_higher(lowest_version, framework_version):
"""Determine whether the ``framework_version`` is equal to or higher than
``lowest_version``

Args:
lowest_version (List[int]): lowest version represented in an integer
list
framework_version (str): framework version string

Returns:
bool: Whether or not framework_version is equal to or higher than
lowest_version
bool: Whether or not ``framework_version`` is equal to or higher than
``lowest_version``
"""
version_list = [int(s) for s in framework_version.split(".")]
return version_list >= lowest_version[0 : len(version_list)]


def _is_merged_versions(framework, framework_version):
"""
def _is_dlc_version(framework, framework_version, py_version):
"""Return if the framework's version uses the corresponding DLC image.

Args:
framework:
framework_version:
framework (str): The framework name, e.g. "tensorflow-scriptmode"
framework_version (str): The framework version
py_version (str): The Python version, e.g. "py3"

Returns:
bool: Whether or not the framework's version uses the DLC image.
"""
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
if isinstance(lowest_version_list, dict):
lowest_version_list = lowest_version_list[py_version]

if lowest_version_list:
return is_version_equal_or_higher(lowest_version_list, framework_version)
return False


def _using_merged_images(region, framework, py_version, framework_version):
"""
Args:
region:
framework:
py_version:
accelerator_type:
framework_version:
"""
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
not_py2 = py_version == "py3" or py_version is None
is_merged_versions = _is_merged_versions(framework, framework_version)

return (
((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION)
and is_merged_versions
# TODO: should be not mxnet-1.14.1-py2 instead?
and (
not_py2
or _is_tf_14_or_later(framework, framework_version)
or _is_pt_12_or_later(framework, framework_version)
or _is_mxnet_16_or_later(framework, framework_version)
)
)

def _use_dlc_image(region, framework, py_version, framework_version):
"""Return if the DLC image should be used for the given framework,
framework version, Python version, and region.

def _is_tf_14_or_later(framework, framework_version):
"""
Args:
framework:
framework_version:
"""
# Asimov team now owns Tensorflow 1.14.0 py2 and py3
asimov_lowest_tf_py2 = [1, 14, 0]
version = [int(s) for s in framework_version.split(".")]
return (
framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2[0 : len(version)]
)

region (str): The AWS region.
framework (str): The framework name, e.g. "tensorflow-scriptmode".
py_version (str): The Python version, e.g. "py3".
framework_version (str): The framework version.

def _is_pt_12_or_later(framework, framework_version):
"""
Args:
framework: Name of the frameowork
framework_version: framework version
Returns:
bool: Whether or not to use the corresponding DLC image.
"""
# Asimov team now owns PyTorch 1.2.0 py2 and py3
asimov_lowest_pt = [1, 2, 0]
version = [int(s) for s in framework_version.split(".")]
is_pytorch = framework in ("pytorch", "pytorch-serving")
return is_pytorch and version >= asimov_lowest_pt[0 : len(version)]
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
is_dlc_version = _is_dlc_version(framework, framework_version, py_version)


def _is_mxnet_16_or_later(framework, framework_version):
"""
Args:
framework: Name of the frameowork
framework_version: framework version
"""
# Asimov team now owns MXNet 1.6.0 py2 and py3
asimov_lowest_pt = [1, 6, 0]
version = [int(s) for s in framework_version.split(".")]
is_mxnet = framework in ("mxnet", "mxnet-serving")
return is_mxnet and version >= asimov_lowest_pt[0 : len(version)]
return ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION) and is_dlc_version


def _registry_id(region, framework, py_version, account, framework_version):
"""
"""Return the Amazon ECR registry number (or AWS account ID) for
the given framework, framework version, Python version, and region.

Args:
region:
framework:
py_version:
account:
accelerator_type:
framework_version:
region (str): The AWS region.
framework (str): The framework name, e.g. "tensorflow-scriptmode".
py_version (str): The Python version, e.g. "py3".
account (str): The AWS account ID to use as a default.
framework_version (str): The framework version.

Returns:
str: The appropriate Amazon ECR registry number. If there is no
specific one for the framework, framework version, Python version,
and region, then ``account`` is returned.
"""
if _using_merged_images(region, framework, py_version, framework_version):
if _use_dlc_image(region, framework, py_version, framework_version):
if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION:
return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION.get(region)
if region in ASIMOV_VALID_ACCOUNTS_BY_REGION:
Expand All @@ -211,6 +179,7 @@ def create_image_uri(
optimized_families=None,
):
"""Return the ECR URI of an image.

Args:
region (str): AWS region where the image is uploaded.
framework (str): framework used by the image.
Expand All @@ -225,6 +194,7 @@ def create_image_uri(
accelerator_type (str): SageMaker Elastic Inference accelerator type.
optimized_families (str): Instance families for which there exist
specific optimized images.

Returns:
str: The appropriate image URI based on the given parameters.
"""
Expand All @@ -240,7 +210,7 @@ def create_image_uri(
):
framework += "-eia"

# Handle Account Number for Gov Cloud and frameworks with DLC merged images
# Handle account number for specific cases (e.g. GovCloud, opt-in regions, DLC images etc.)
if account is None:
account = _registry_id(
region=region,
Expand Down Expand Up @@ -271,18 +241,19 @@ def create_image_uri(
else:
device_type = "cpu"

using_merged_images = _using_merged_images(region, framework, py_version, framework_version)
use_dlc_image = _use_dlc_image(region, framework, py_version, framework_version)

if not py_version or (using_merged_images and framework == "tensorflow-serving-eia"):
if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia"):
tag = "{}-{}".format(framework_version, device_type)
else:
tag = "{}-{}-{}".format(framework_version, device_type, py_version)

if using_merged_images:
return "{}/{}:{}".format(
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
)
return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)
if use_dlc_image:
ecr_repo = MERGED_FRAMEWORKS_REPO_MAP[framework]
else:
ecr_repo = "sagemaker-{}".format(framework)

return "{}/{}:{}".format(get_ecr_image_uri_prefix(account, region), ecr_repo, tag)


def _accelerator_type_valid_for_framework(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class MXNetModel(FrameworkModel):
"""An MXNet SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

__framework_name__ = "mxnet"
_LOWEST_MMS_VERSION = "1.4"
_LOWEST_MMS_VERSION = "1.4.0"

def __init__(
self,
Expand Down