Skip to content

Commit 20fc81a

Browse files
author
Kim
committed
Addressed comments
1 parent 69e6fe0 commit 20fc81a

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,10 @@ def get_image_uri(region_name, repo_name, repo_version=1):
636636
if repo_version == version.split("-")[0]
637637
]
638638
if xgboost_version_matches:
639-
# Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest,
640-
# and the latest version is at the end of the list.
641-
repo_version = xgboost_version_matches[-1]
639+
# Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest.
640+
# When SageMaker version is not specified, we use the oldest one that matches
641+
# XGBoost version for backward compatibility.
642+
repo_version = xgboost_version_matches[0]
642643

643644
supported_framework_versions = [
644645
version

tests/unit/test_amazon_estimator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ def test_get_xgboost_image_uri():
475475
)
476476
assert (
477477
get_image_uri(REGION, "xgboost", "0.90")
478-
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
478+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
479479
)
480480
assert (
481481
get_image_uri(REGION, "xgboost", "1.0-1")
@@ -527,10 +527,8 @@ def test_is_latest_xgboost_version():
527527
for version in XGBOOST_SUPPORTED_VERSIONS:
528528
if version != XGBOOST_LATEST_VERSION:
529529
assert _is_latest_xgboost_version(version) is False
530-
531-
assert _is_latest_xgboost_version("0.90-1-cpu-py3") is False
532-
assert _is_latest_xgboost_version("0.90-2-cpu-py3") is False
533-
assert _is_latest_xgboost_version(XGBOOST_LATEST_VERSION) is True
530+
else:
531+
assert _is_latest_xgboost_version(version) is True
534532

535533

536534
def test_get_image_uri_warn(caplog):

0 commit comments

Comments
 (0)