Skip to content

Commit 5bdd3a5

Browse files
author
Mark Bunday
committed
fix: Return ARM XGB/SKLearn tags if image_scope is inference_graviton
1 parent 885423c commit 5bdd3a5

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

src/sagemaker/image_uris.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
XGBOOST_FRAMEWORK = "xgboost"
3535
SKLEARN_FRAMEWORK = "sklearn"
3636
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
37+
INFERENCE_GRAVITON = "inference_graviton"
3738

3839

3940
@override_pipeline_parameter_var
@@ -146,8 +147,9 @@ def retrieve(
146147
)
147148

148149
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
150+
final_image_scope = image_scope
149151
config = _config_for_framework_and_scope(
150-
framework + "-training-compiler", image_scope, accelerator_type
152+
framework + "-training-compiler", final_image_scope, accelerator_type
151153
)
152154
else:
153155
_framework = framework
@@ -234,6 +236,7 @@ def retrieve(
234236
tag = _get_image_tag(
235237
container_version,
236238
distribution,
239+
final_image_scope,
237240
framework,
238241
inference_tool,
239242
instance_type,
@@ -266,6 +269,7 @@ def _get_instance_type_family(instance_type):
266269
def _get_image_tag(
267270
container_version,
268271
distribution,
272+
final_image_scope,
269273
framework,
270274
inference_tool,
271275
instance_type,
@@ -276,9 +280,9 @@ def _get_image_tag(
276280
):
277281
"""Return image tag based on framework, container, and compute configuration(s)."""
278282
instance_type_family = _get_instance_type_family(instance_type)
279-
if (
280-
framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK)
281-
and instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
283+
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK) and (
284+
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
285+
or final_image_scope == INFERENCE_GRAVITON
282286
):
283287
version_to_arm64_tag_mapping = {
284288
"xgboost": {
@@ -375,7 +379,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
375379
framework in GRAVITON_ALLOWED_FRAMEWORKS
376380
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
377381
):
378-
return "inference_graviton"
382+
return INFERENCE_GRAVITON
379383
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
380384
# Preserves backwards compatibility with XGB/SKLearn configs which no
381385
# longer define top-level "scope" keys after introducing support for

tests/unit/sagemaker/image_uris/test_graviton.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
8989
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
9090

9191

92-
def test_graviton_xgboost(graviton_xgboost_versions):
92+
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
9393
for xgboost_version in graviton_xgboost_versions:
9494
for instance_type in GRAVITON_INSTANCE_TYPES:
9595
uri = image_uris.retrieve(
@@ -102,6 +102,19 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102102
assert expected == uri
103103

104104

105+
def test_graviton_xgboost_image_scope_specified(graviton_xgboost_versions):
106+
for xgboost_version in graviton_xgboost_versions:
107+
for instance_type in GRAVITON_INSTANCE_TYPES:
108+
uri = image_uris.retrieve(
109+
"xgboost", "us-west-2", version=xgboost_version, image_scope="inference_graviton"
110+
)
111+
expected = (
112+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113+
f"{xgboost_version}-arm64"
114+
)
115+
assert expected == uri
116+
117+
105118
def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versions):
106119
for xgboost_version in graviton_xgboost_unsupported_versions:
107120
for instance_type in GRAVITON_INSTANCE_TYPES:
@@ -112,7 +125,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112125
assert f"Unsupported xgboost version: {xgboost_version}." in str(error)
113126

114127

115-
def test_graviton_sklearn(graviton_sklearn_versions):
128+
def test_graviton_sklearn_instance_type_specified(graviton_sklearn_versions):
116129
for sklearn_version in graviton_sklearn_versions:
117130
for instance_type in GRAVITON_INSTANCE_TYPES:
118131
uri = image_uris.retrieve(
@@ -125,6 +138,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125138
assert expected == uri
126139

127140

141+
def test_graviton_sklearn_image_scope_specified(graviton_sklearn_versions):
142+
for sklearn_version in graviton_sklearn_versions:
143+
for instance_type in GRAVITON_INSTANCE_TYPES:
144+
uri = image_uris.retrieve(
145+
"sklearn", "us-west-2", version=sklearn_version, image_scope="inference_graviton"
146+
)
147+
expected = (
148+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
149+
f"{sklearn_version}-arm64-cpu-py3"
150+
)
151+
assert expected == uri
152+
153+
128154
def test_graviton_sklearn_unsupported_version(graviton_sklearn_unsupported_versions):
129155
for sklearn_version in graviton_sklearn_unsupported_versions:
130156
for instance_type in GRAVITON_INSTANCE_TYPES:

0 commit comments

Comments
 (0)