Skip to content

Commit 16768ff

Browse files
author
Kim
committed
change: blacklist unknown xgboost image versions
1 parent f14676a commit 16768ff

File tree

3 files changed

+96
-23
lines changed

3 files changed

+96
-23
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from sagemaker.session import s3_input
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
3030
from sagemaker.xgboost.defaults import (
31+
XGBOOST_1P_VERSIONS,
3132
XGBOOST_LATEST_VERSION,
33+
XGBOOST_NAME,
3234
XGBOOST_SUPPORTED_VERSIONS,
33-
XGBOOST_VERSION_0_90_1,
34-
XGBOOST_VERSION_0_90,
3535
XGBOOST_VERSION_EQUIVALENTS,
3636
)
3737
from sagemaker.xgboost.estimator import get_xgboost_image_uri
@@ -616,41 +616,67 @@ def get_image_uri(region_name, repo_name, repo_version=1):
616616
repo_name:
617617
repo_version:
618618
"""
619-
if repo_name == "xgboost":
620-
if not _is_latest_xgboost_version(repo_version):
621-
logging.warning(
622-
"There is a more up to date SageMaker XGBoost image. "
623-
"To use the newer image, please set 'repo_version'="
624-
"'%s'. For example:\n"
625-
"\tget_image_uri(region, 'xgboost', '%s').",
626-
XGBOOST_LATEST_VERSION,
627-
XGBOOST_LATEST_VERSION,
628-
)
629-
630-
if repo_version in [XGBOOST_VERSION_0_90] + _generate_version_equivalents(
631-
XGBOOST_VERSION_0_90_1
632-
):
633-
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_0_90_1)
634-
635-
supported_version = [
619+
repo_version = str(repo_version)
620+
621+
if repo_name == XGBOOST_NAME:
622+
623+
if repo_version in XGBOOST_1P_VERSIONS:
624+
_warn_newer_xgboost_image()
625+
return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version)
626+
627+
if "-" not in repo_version:
628+
xgboost_version_matches = [
629+
version
630+
for version in XGBOOST_SUPPORTED_VERSIONS
631+
if repo_version == version.split("-")[0]
632+
]
633+
if xgboost_version_matches:
634+
# Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest,
635+
# and the latest version is at the end of the list.
636+
repo_version = xgboost_version_matches[-1]
637+
638+
supported_framework_versions = [
636639
version
637640
for version in XGBOOST_SUPPORTED_VERSIONS
638641
if repo_version in _generate_version_equivalents(version)
639642
]
640-
if supported_version:
641-
return get_xgboost_image_uri(region_name, supported_version[0])
643+
644+
if not supported_framework_versions:
645+
raise ValueError(
646+
"SageMaker XGBoost version {} is not supported. Supported versions: {}".format(
647+
repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS)
648+
)
649+
)
650+
651+
if not _is_latest_xgboost_version(repo_version):
652+
_warn_newer_xgboost_image()
653+
654+
return get_xgboost_image_uri(region_name, supported_framework_versions[-1])
642655

643656
repo = "{}:{}".format(repo_name, repo_version)
644657
return "{}/{}".format(registry(region_name, repo_name), repo)
645658

646659

660+
def _warn_newer_xgboost_image():
661+
"""Print a warning when there is a newer XGBoost image"""
662+
logging.warning(
663+
"There is a more up to date SageMaker XGBoost image. "
664+
"To use the newer image, please set 'repo_version'="
665+
"'%s'. For example:\n"
666+
"\tget_image_uri(region, '%s', '%s').",
667+
XGBOOST_LATEST_VERSION,
668+
XGBOOST_NAME,
669+
XGBOOST_LATEST_VERSION,
670+
)
671+
672+
647673
def _is_latest_xgboost_version(repo_version):
648674
"""Compare xgboost image version with latest version
649675
650676
Args:
651677
repo_version:
652678
"""
653-
if repo_version in (1, "latest"):
679+
if repo_version in XGBOOST_1P_VERSIONS:
654680
return False
655681
return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION)
656682

src/sagemaker/xgboost/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
XGBOOST_NAME = "xgboost"
17+
XGBOOST_1P_VERSIONS = ["1", "latest"]
1718
XGBOOST_VERSION_0_90 = "0.90"
1819
XGBOOST_VERSION_0_90_1 = "0.90-1"
1920
XGBOOST_VERSION_0_90_2 = "0.90-2"
2021
XGBOOST_LATEST_VERSION = "1.0-1"
22+
# XGBOOST_SUPPORTED_VERSIONS has XGBoost Framework versions sorted from oldest to latest
2123
XGBOOST_SUPPORTED_VERSIONS = [
2224
XGBOOST_VERSION_0_90_1,
2325
XGBOOST_VERSION_0_90_2,

tests/unit/test_amazon_estimator.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ def test_file_system_record_set_data_channel():
452452
def test_get_xgboost_image_uri():
453453
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost")
454454
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
455+
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", 1)
456+
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1"
457+
legacy_xgb_image_uri = get_image_uri(REGION, "xgboost", "latest")
458+
assert legacy_xgb_image_uri == "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest"
455459

456460
updated_xgb_image_uri = get_image_uri(REGION, "xgboost", "0.90-1")
457461
assert (
@@ -465,6 +469,47 @@ def test_get_xgboost_image_uri():
465469
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
466470
)
467471

472+
assert (
473+
get_image_uri(REGION, "xgboost", "0.90-2-cpu-py3")
474+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
475+
)
476+
assert (
477+
get_image_uri(REGION, "xgboost", "0.90")
478+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
479+
)
480+
assert (
481+
get_image_uri(REGION, "xgboost", "1.0-1")
482+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
483+
)
484+
assert (
485+
get_image_uri(REGION, "xgboost", "1.0-1-cpu-py3")
486+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
487+
)
488+
assert (
489+
get_image_uri(REGION, "xgboost", "1.0")
490+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3"
491+
)
492+
493+
494+
def test_get_xgboost_image_uri_warning_with_legacy(caplog):
495+
get_image_uri(REGION, "xgboost", 1)
496+
assert "There is a more up to date SageMaker XGBoost image." in caplog.text
497+
498+
499+
def test_get_xgboost_image_uri_no_warning_with_latest(caplog):
500+
get_image_uri(REGION, "xgboost", XGBOOST_LATEST_VERSION.split("-")[0])
501+
assert "There is a more up to date SageMaker XGBoost image." not in caplog.text
502+
503+
504+
def test_get_xgboost_image_uri_throws_error_for_unsupported_version():
505+
with pytest.raises(ValueError) as error:
506+
get_image_uri(REGION, "xgboost", "99.99-9")
507+
assert "SageMaker XGBoost version 99.99-9 is not supported" in str(error)
508+
509+
with pytest.raises(ValueError) as error:
510+
get_image_uri(REGION, "xgboost", "0.90-1-gpu-py3")
511+
assert "SageMaker XGBoost version 0.90-1-gpu-py3 is not supported" in str(error)
512+
468513

469514
def test_regitry_throws_error_if_mapping_does_not_exist_for_lda():
470515
with pytest.raises(ValueError) as error:
@@ -484,5 +529,5 @@ def test_is_latest_xgboost_version():
484529
assert _is_latest_xgboost_version(version) is False
485530

486531
assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False
487-
532+
assert _is_latest_xgboost_version("0.90-2-cpu-py3") is False
488533
assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True

0 commit comments

Comments
 (0)