-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: handle separate training/inference images and EI in image_uris.retrieve #1707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,14 +14,25 @@ | |
from __future__ import absolute_import | ||
|
||
import json | ||
import logging | ||
import os | ||
|
||
from sagemaker import utils | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}" | ||
|
||
|
||
def retrieve(framework, region, version=None, py_version=None, instance_type=None): | ||
def retrieve( | ||
framework, | ||
region, | ||
version=None, | ||
py_version=None, | ||
instance_type=None, | ||
accelerator_type=None, | ||
image_scope=None, | ||
): | ||
"""Retrieves the ECR URI for the Docker image matching the given arguments. | ||
|
||
Args: | ||
|
@@ -34,28 +45,48 @@ def retrieve(framework, region, version=None, py_version=None, instance_type=Non | |
instance_type (str): The SageMaker instance type. For supported types, see | ||
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if | ||
there are different images for different processor types. | ||
accelerator_type (str): Elastic Inference accelerator type. For more, see | ||
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. | ||
image_scope (str): The image type, i.e. what it is used for. | ||
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, | ||
``image_scope`` is ignored. | ||
|
||
Returns: | ||
str: the ECR URI for the corresponding SageMaker Docker image. | ||
|
||
Raises: | ||
ValueError: If the framework version, Python version, processor type, or region is | ||
not supported given the other arguments. | ||
ValueError: If the combination of arguments specified is not supported. | ||
""" | ||
config = config_for_framework(framework) | ||
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type) | ||
version_config = config["versions"][_version_for_config(version, config, framework)] | ||
|
||
py_version = _validate_py_version_and_set_if_needed(py_version, version_config) | ||
version_config = version_config.get(py_version) or version_config | ||
|
||
registry = _registry_from_region(region, version_config["registries"]) | ||
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"] | ||
|
||
repo = version_config["repository"] | ||
|
||
_validate_py_version(py_version, version_config["py_versions"], framework, version) | ||
tag = "{}-{}-{}".format(version, _processor(instance_type, config["processors"]), py_version) | ||
tag = _format_tag(version, _processor(instance_type, config["processors"]), py_version) | ||
|
||
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag) | ||
|
||
|
||
def _config_for_framework_and_scope(framework, image_scope, accelerator_type=None): | ||
"""Loads the JSON config for the given framework and image scope.""" | ||
config = config_for_framework(framework) | ||
|
||
if accelerator_type: | ||
if image_scope not in ("eia", "inference"): | ||
logger.info( | ||
metrizable marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"Elastic inference is for inference only. Ignoring image scope: %s.", image_scope | ||
) | ||
image_scope = "eia" | ||
|
||
_validate_arg("image scope", image_scope, config.get("scope", config.keys())) | ||
return config if "scope" in config else config[image_scope] | ||
|
||
|
||
def config_for_framework(framework): | ||
"""Loads the JSON config for the given framework.""" | ||
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework)) | ||
|
@@ -69,27 +100,13 @@ def _version_for_config(version, config, framework): | |
if version in config["version_aliases"].keys(): | ||
return config["version_aliases"][version] | ||
|
||
available_versions = config["versions"].keys() | ||
if version in available_versions: | ||
return version | ||
|
||
raise ValueError( | ||
"Unsupported {} version: {}. " | ||
"You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. " | ||
"Supported version(s): {}.".format(framework, version, ", ".join(available_versions)) | ||
) | ||
_validate_arg("{} version".format(framework), version, config["versions"].keys()) | ||
return version | ||
|
||
|
||
def _registry_from_region(region, registry_dict): | ||
"""Returns the ECR registry (AWS account number) for the given region.""" | ||
available_regions = registry_dict.keys() | ||
if region not in available_regions: | ||
raise ValueError( | ||
"Unsupported region: {}. You may need to upgrade " | ||
"your SDK version (pip install -U sagemaker) for newer regions. " | ||
"Supported region(s): {}.".format(region, ", ".join(available_regions)) | ||
) | ||
|
||
_validate_arg("region", region, registry_dict.keys()) | ||
return registry_dict[region] | ||
|
||
|
||
|
@@ -106,22 +123,41 @@ def _processor(instance_type, available_processors): | |
family = instance_type.split(".")[1] | ||
processor = "gpu" if family[0] in ("g", "p") else "cpu" | ||
|
||
if processor in available_processors: | ||
return processor | ||
|
||
raise ValueError( | ||
"Unsupported processor type: {} (for {}). " | ||
"Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors)) | ||
) | ||
_validate_arg("processor", processor, available_processors) | ||
return processor | ||
|
||
|
||
def _validate_py_version(py_version, available_versions, framework, fw_version): | ||
def _validate_py_version_and_set_if_needed(py_version, version_config): | ||
"""Checks if the Python version is one of the supported versions.""" | ||
if py_version not in available_versions: | ||
available_versions = version_config.get("py_versions", version_config.keys()) | ||
|
||
if len(available_versions) == 0: | ||
if py_version: | ||
logger.info("Ignoring unnecessary Python version: %s.", py_version) | ||
return None | ||
|
||
if py_version is None and len(available_versions) == 1: | ||
logger.info("Defaulting to only available Python version: %s", available_versions[0]) | ||
return available_versions[0] | ||
|
||
_validate_arg("Python version", py_version, available_versions) | ||
return py_version | ||
|
||
|
||
def _validate_arg(arg_name, arg, available_options): | ||
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not.""" | ||
if arg not in available_options: | ||
raise ValueError( | ||
"Unsupported Python version for {} {}: {}. You may need to upgrade " | ||
"your SDK version (pip install -U sagemaker) for newer versions. " | ||
"Supported Python version(s): {}.".format( | ||
framework, fw_version, py_version, ", ".join(available_versions) | ||
) | ||
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version " | ||
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): " | ||
"{options}.".format(arg_name=arg_name, arg=arg, options=", ".join(available_options)) | ||
) | ||
|
||
|
||
def _format_tag(version, processor, py_version): | ||
"""Creates a tag for the image URI.""" | ||
tag = "{}-{}".format(version, processor) | ||
if py_version: | ||
tag += "-{}".format(py_version) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. re: this, unify the way you do this? |
||
|
||
return tag |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -164,50 +164,20 @@ def xgboost_version(request): | |
return request.param | ||
|
||
|
||
@pytest.fixture( | ||
scope="module", | ||
params=[ | ||
"1.4", | ||
"1.4.1", | ||
"1.5", | ||
"1.5.0", | ||
"1.6", | ||
"1.6.0", | ||
"1.7", | ||
"1.7.0", | ||
"1.8", | ||
"1.8.0", | ||
"1.9", | ||
"1.9.0", | ||
"1.10", | ||
"1.10.0", | ||
"1.11", | ||
"1.11.0", | ||
"1.12", | ||
"1.12.0", | ||
"1.13", | ||
"1.14", | ||
"1.14.0", | ||
"1.15", | ||
"1.15.0", | ||
"1.15.2", | ||
"2.0", | ||
"2.0.0", | ||
"2.0.1", | ||
"2.1", | ||
"2.1.0", | ||
], | ||
) | ||
def tf_version(request): | ||
return request.param | ||
@pytest.fixture(scope="module") | ||
def tf_version(tensorflow_training_version): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will there be a […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this PR creates […] |
||
# TODO: remove this fixture and update tests | ||
if tensorflow_training_version in ("1.13.1", "2.2", "2.2.0"): | ||
pytest.skip("version isn't compatible with both training and inference.") | ||
return tensorflow_training_version | ||
|
||
|
||
@pytest.fixture(scope="module", params=["py2", "py3"]) | ||
def tf_py_version(tf_version, request): | ||
version = [int(val) for val in tf_version.split(".")] | ||
if version < [1, 11]: | ||
def tf_py_version(tensorflow_training_version, request): | ||
version = Version(tensorflow_training_version) | ||
if version < Version("1.11"): | ||
return "py2" | ||
if version < [2, 2]: | ||
if version < Version("2.2"): | ||
return request.param | ||
return "py37" | ||
|
||
|
@@ -401,11 +371,22 @@ def pytest_generate_tests(metafunc): | |
params.append("ml.p2.xlarge") | ||
metafunc.parametrize("instance_type", params, scope="session") | ||
|
||
for fw in ("chainer",): | ||
fixture_name = "{}_version".format(fw) | ||
if fixture_name in metafunc.fixturenames: | ||
config = image_uris.config_for_framework(fw) | ||
versions = list(config["versions"].keys()) + list( | ||
config.get("version_aliases", {}).keys() | ||
) | ||
metafunc.parametrize(fixture_name, versions, scope="session") | ||
_generate_all_framework_version_fixtures(metafunc) | ||
|
||
|
||
def _generate_all_framework_version_fixtures(metafunc): | ||
for fw in ("chainer", "tensorflow"): | ||
config = image_uris.config_for_framework(fw) | ||
if "scope" in config: | ||
_parametrize_framework_version_fixture(metafunc, "{}_version".format(fw), config) | ||
else: | ||
for image_scope in config.keys(): | ||
_parametrize_framework_version_fixture( | ||
metafunc, "{}_{}_version".format(fw, image_scope), config[image_scope] | ||
) | ||
|
||
|
||
def _parametrize_framework_version_fixture(metafunc, fixture_name, config): | ||
if fixture_name in metafunc.fixturenames: | ||
versions = list(config["versions"].keys()) + list(config.get("version_aliases", {}).keys()) | ||
metafunc.parametrize(fixture_name, versions, scope="session") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are differences in the `logger` definitions in various parts of the code. What is your vision on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I had a vision for this - I just went with whatever looked reasonable after doing a `grep -r "logger =" src` 😅