Skip to content

fix: Return ARM XGB/SKLearn tags if image_scope is inference_graviton #3449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
XGBOOST_FRAMEWORK = "xgboost"
SKLEARN_FRAMEWORK = "sklearn"
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
INFERENCE_GRAVITON = "inference_graviton"


@override_pipeline_parameter_var
Expand Down Expand Up @@ -75,8 +76,8 @@ def retrieve(
accelerator_type (str): Elastic Inference accelerator type. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
Valid values: "training", "inference", "inference_graviton", "eia".
If ``accelerator_type`` is set, ``image_scope`` is ignored.
container_version (str): the version of docker image.
Ideally the value of parameter should be created inside the framework.
For custom use, see the list of supported container versions:
Expand Down Expand Up @@ -146,8 +147,9 @@ def retrieve(
)

if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
final_image_scope = image_scope
config = _config_for_framework_and_scope(
framework + "-training-compiler", image_scope, accelerator_type
framework + "-training-compiler", final_image_scope, accelerator_type
)
else:
_framework = framework
Expand Down Expand Up @@ -234,6 +236,7 @@ def retrieve(
tag = _get_image_tag(
container_version,
distribution,
final_image_scope,
framework,
inference_tool,
instance_type,
Expand Down Expand Up @@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
def _get_image_tag(
container_version,
distribution,
final_image_scope,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also update the docs in line 79? I'll ping the SDK team on Slack about updating the public docs.

framework,
inference_tool,
instance_type,
Expand All @@ -276,20 +280,29 @@ def _get_image_tag(
):
"""Return image tag based on framework, container, and compute configuration(s)."""
instance_type_family = _get_instance_type_family(instance_type)
if (
framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
):
version_to_arm64_tag_mapping = {
"xgboost": {
"1.5-1": "1.5-1-arm64",
"1.3-1": "1.3-1-arm64",
},
"sklearn": {
"1.0-1": "1.0-1-arm64-cpu-py3",
},
}
tag = version_to_arm64_tag_mapping[framework][version]
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
_validate_arg(
instance_type_family,
GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY,
"instance type",
)
if (
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
or final_image_scope == INFERENCE_GRAVITON
):
version_to_arm64_tag_mapping = {
"xgboost": {
"1.5-1": "1.5-1-arm64",
"1.3-1": "1.3-1-arm64",
},
"sklearn": {
"1.0-1": "1.0-1-arm64-cpu-py3",
},
}
tag = version_to_arm64_tag_mapping[framework][version]
else:
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
else:
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)

Expand Down Expand Up @@ -375,7 +388,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
framework in GRAVITON_ALLOWED_FRAMEWORKS
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
):
return "inference_graviton"
return INFERENCE_GRAVITON
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
# Preserves backwards compatibility with XGB/SKLearn configs which no
# longer define top-level "scope" keys after introducing support for
Expand Down
58 changes: 56 additions & 2 deletions tests/unit/sagemaker/image_uris/test_graviton.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)


def test_graviton_xgboost(graviton_xgboost_versions):
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
for xgboost_version in graviton_xgboost_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
uri = image_uris.retrieve(
Expand All @@ -102,6 +102,33 @@ def test_graviton_xgboost(graviton_xgboost_versions):
assert expected == uri


def test_graviton_xgboost_image_scope_specified(graviton_xgboost_versions):
for xgboost_version in graviton_xgboost_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
uri = image_uris.retrieve(
"xgboost", "us-west-2", version=xgboost_version, image_scope="inference_graviton"
)
expected = (
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
f"{xgboost_version}-arm64"
)
assert expected == uri


def test_graviton_xgboost_image_scope_specified_x86_instance(graviton_xgboost_versions):
for xgboost_version in graviton_xgboost_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
with pytest.raises(ValueError) as error:
image_uris.retrieve(
"xgboost",
"us-west-2",
version=xgboost_version,
image_scope="inference_graviton",
instance_type="ml.m5.xlarge",
)
assert "Unsupported instance type: m5." in str(error)


def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versions):
for xgboost_version in graviton_xgboost_unsupported_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
Expand All @@ -112,7 +139,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
assert f"Unsupported xgboost version: {xgboost_version}." in str(error)


def test_graviton_sklearn(graviton_sklearn_versions):
def test_graviton_sklearn_instance_type_specified(graviton_sklearn_versions):
for sklearn_version in graviton_sklearn_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
uri = image_uris.retrieve(
Expand All @@ -125,6 +152,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
assert expected == uri


def test_graviton_sklearn_image_scope_specified(graviton_sklearn_versions):
for sklearn_version in graviton_sklearn_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
uri = image_uris.retrieve(
"sklearn", "us-west-2", version=sklearn_version, image_scope="inference_graviton"
)
expected = (
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
f"{sklearn_version}-arm64-cpu-py3"
)
assert expected == uri


def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versions):
for sklearn_version in graviton_sklearn_unsupported_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
Expand All @@ -138,6 +178,20 @@ def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versi
assert expected == uri


def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_unsupported_versions):
for sklearn_version in graviton_sklearn_unsupported_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
with pytest.raises(ValueError) as error:
image_uris.retrieve(
"sklearn",
"us-west-2",
version=sklearn_version,
image_scope="inference_graviton",
instance_type="ml.m5.xlarge",
)
assert "Unsupported instance type: m5." in str(error)


def _expected_graviton_framework_uri(framework, version, region):
return expected_uris.graviton_framework_uri(
"{}-inference-graviton".format(framework),
Expand Down