Skip to content

feature: handle separate training/inference images and EI in image_uris.retrieve #1707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/image_uri_config/chainer.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"processors": ["cpu", "gpu"],
"scope": ["inference", "training"],
"version_aliases": {
"4.0": "4.0.0",
"4.1": "4.1.0",
Expand Down
1,207 changes: 1,207 additions & 0 deletions src/sagemaker/image_uri_config/tensorflow.json

Large diffs are not rendered by default.

108 changes: 70 additions & 38 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,25 @@
from __future__ import absolute_import

import json
import logging
import os

from sagemaker import utils

logger = logging.getLogger(__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are difference in logger definitions in various parts of the code. What is your vision on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I had a vision for this - just went with whatever looked reasonable after doing a grep -r "logger =" src 😅


ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"


def retrieve(framework, region, version=None, py_version=None, instance_type=None):
def retrieve(
framework,
region,
version=None,
py_version=None,
instance_type=None,
accelerator_type=None,
image_scope=None,
):
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Args:
Expand All @@ -34,28 +45,48 @@ def retrieve(framework, region, version=None, py_version=None, instance_type=Non
instance_type (str): The SageMaker instance type. For supported types, see
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
there are different images for different processor types.
accelerator_type (str): Elastic Inference accelerator type. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the framework version, Python version, processor type, or region is
not supported given the other arguments.
ValueError: If the combination of arguments specified is not supported.
"""
config = config_for_framework(framework)
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
version_config = config["versions"][_version_for_config(version, config, framework)]

py_version = _validate_py_version_and_set_if_needed(py_version, version_config)
version_config = version_config.get(py_version) or version_config

registry = _registry_from_region(region, version_config["registries"])
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]

repo = version_config["repository"]

_validate_py_version(py_version, version_config["py_versions"], framework, version)
tag = "{}-{}-{}".format(version, _processor(instance_type, config["processors"]), py_version)
tag = _format_tag(version, _processor(instance_type, config["processors"]), py_version)

return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)


def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None):
"""Loads the JSON config for the given framework and image scope."""
config = config_for_framework(framework)

if accelerator_type:
if image_scope not in ("eia", "inference"):
logger.warning(
"Elastic inference is for inference only. Ignoring image scope: %s.", image_scope
)
image_scope = "eia"

_validate_arg("image scope", image_scope, config.get("scope", config.keys()))
return config if "scope" in config else config[image_scope]


def config_for_framework(framework):
"""Loads the JSON config for the given framework."""
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
Expand All @@ -69,27 +100,13 @@ def _version_for_config(version, config, framework):
if version in config["version_aliases"].keys():
return config["version_aliases"][version]

available_versions = config["versions"].keys()
if version in available_versions:
return version

raise ValueError(
"Unsupported {} version: {}. "
"You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
"Supported version(s): {}.".format(framework, version, ", ".join(available_versions))
)
_validate_arg("{} version".format(framework), version, config["versions"].keys())
return version


def _registry_from_region(region, registry_dict):
"""Returns the ECR registry (AWS account number) for the given region."""
available_regions = registry_dict.keys()
if region not in available_regions:
raise ValueError(
"Unsupported region: {}. You may need to upgrade "
"your SDK version (pip install -U sagemaker) for newer regions. "
"Supported region(s): {}.".format(region, ", ".join(available_regions))
)

_validate_arg("region", region, registry_dict.keys())
return registry_dict[region]


Expand All @@ -106,22 +123,37 @@ def _processor(instance_type, available_processors):
family = instance_type.split(".")[1]
processor = "gpu" if family[0] in ("g", "p") else "cpu"

if processor in available_processors:
return processor

raise ValueError(
"Unsupported processor type: {} (for {}). "
"Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors))
)
_validate_arg("processor", processor, available_processors)
return processor


def _validate_py_version(py_version, available_versions, framework, fw_version):
def _validate_py_version_and_set_if_needed(py_version, version_config):
"""Checks if the Python version is one of the supported versions."""
if py_version not in available_versions:
available_versions = version_config.get("py_versions", version_config.keys())

if len(available_versions) == 0:
if py_version:
logger.info("Ignoring unnecessary Python version: %s.", py_version)
return None

if py_version is None and len(available_versions) == 1:
logger.info("Defaulting to only available Python version: %s", available_versions[0])
return available_versions[0]

_validate_arg("Python version", py_version, available_versions)
return py_version


def _validate_arg(arg_name, arg, available_options):
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
if arg not in available_options:
raise ValueError(
"Unsupported Python version for {} {}: {}. You may need to upgrade "
"your SDK version (pip install -U sagemaker) for newer versions. "
"Supported Python version(s): {}.".format(
framework, fw_version, py_version, ", ".join(available_versions)
)
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options))
)


def _format_tag(version, processor, py_version):
"""Creates a tag for the image URI."""
return "-".join([x for x in (version, processor, py_version) if x])
77 changes: 29 additions & 48 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,50 +164,20 @@ def xgboost_version(request):
return request.param


@pytest.fixture(
scope="module",
params=[
"1.4",
"1.4.1",
"1.5",
"1.5.0",
"1.6",
"1.6.0",
"1.7",
"1.7.0",
"1.8",
"1.8.0",
"1.9",
"1.9.0",
"1.10",
"1.10.0",
"1.11",
"1.11.0",
"1.12",
"1.12.0",
"1.13",
"1.14",
"1.14.0",
"1.15",
"1.15.0",
"1.15.2",
"2.0",
"2.0.0",
"2.0.1",
"2.1",
"2.1.0",
],
)
def tf_version(request):
return request.param
@pytest.fixture(scope="module")
def tf_version(tensorflow_training_version):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will there be a tensorflow_serving_version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR creates tensorflow_training_version, tensorflow_inference_version, and tensorflow_ei_version. I'll fix all the fixture usage in the tests in a separate PR - I was worried that this PR had too many changes already 😂

# TODO: remove this fixture and update tests
if tensorflow_training_version in ("1.13.1", "2.2", "2.2.0"):
pytest.skip("version isn't compatible with both training and inference.")
return tensorflow_training_version


@pytest.fixture(scope="module", params=["py2", "py3"])
def tf_py_version(tf_version, request):
version = [int(val) for val in tf_version.split(".")]
if version < [1, 11]:
def tf_py_version(tensorflow_training_version, request):
version = Version(tensorflow_training_version)
if version < Version("1.11"):
return "py2"
if version < [2, 2]:
if version < Version("2.2"):
return request.param
return "py37"

Expand Down Expand Up @@ -401,11 +371,22 @@ def pytest_generate_tests(metafunc):
params.append("ml.p2.xlarge")
metafunc.parametrize("instance_type", params, scope="session")

for fw in ("chainer",):
fixture_name = "{}_version".format(fw)
if fixture_name in metafunc.fixturenames:
config = image_uris.config_for_framework(fw)
versions = list(config["versions"].keys()) + list(
config.get("version_aliases", {}).keys()
)
metafunc.parametrize(fixture_name, versions, scope="session")
_generate_all_framework_version_fixtures(metafunc)


def _generate_all_framework_version_fixtures(metafunc):
for fw in ("chainer", "tensorflow"):
config = image_uris.config_for_framework(fw)
if "scope" in config:
_parametrize_framework_version_fixture(metafunc, "{}_version".format(fw), config)
else:
for image_scope in config.keys():
_parametrize_framework_version_fixture(
metafunc, "{}_{}_version".format(fw, image_scope), config[image_scope]
)


def _parametrize_framework_version_fixture(metafunc, fixture_name, config):
if fixture_name in metafunc.fixturenames:
versions = list(config["versions"].keys()) + list(config.get("version_aliases", {}).keys())
metafunc.parametrize(fixture_name, versions, scope="session")
Loading