Skip to content

Commit 31df3c6

Browse files
aws-patlinyuanzhua
authored andcommitted
Add support for xgboost version 0.90-2 (aws#258)
* change: add support for xgboost version 0.90-2 * fix: remove duplicate rule in test_debugger.py:test_mxnet_with_all_rules_and_configs
1 parent a3c7d93 commit 31df3c6

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from sagemaker.model import NEO_IMAGE_ACCOUNT
2828
from sagemaker.session import s3_input
2929
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
30+
from sagemaker.xgboost.defaults import XGBOOST_VERSION_1, XGBOOST_SUPPORTED_VERSIONS
3031
from sagemaker.xgboost.estimator import get_xgboost_image_uri
31-
from sagemaker.xgboost.defaults import XGBOOST_LATEST_VERSION
3232

3333
logger = logging.getLogger(__name__)
3434

@@ -559,13 +559,23 @@ def get_image_uri(region_name, repo_name, repo_version=1):
559559
"""
560560
if repo_name == "xgboost":
561561
if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]:
562-
return get_xgboost_image_uri(region_name, XGBOOST_LATEST_VERSION)
562+
return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1)
563+
564+
supported_version = [
565+
version
566+
for version in XGBOOST_SUPPORTED_VERSIONS
567+
if repo_version in (version, version + "-cpu-py3")
568+
]
569+
if supported_version:
570+
return get_xgboost_image_uri(region_name, supported_version[0])
571+
563572
logging.warning(
564-
"There is a more up to date SageMaker XGBoost image."
573+
"There is a more up to date SageMaker XGBoost image. "
565574
"To use the newer image, please set 'repo_version'="
566-
"'0.90-1. For example:\n"
575+
"'%s'. For example:\n"
567576
"\tget_image_uri(region, 'xgboost', '%s').",
568-
XGBOOST_LATEST_VERSION,
577+
XGBOOST_VERSION_1,
578+
XGBOOST_VERSION_1,
569579
)
570580
repo = "{}:{}".format(repo_name, repo_version)
571581
return "{}/{}".format(registry(region_name, repo_name), repo)

src/sagemaker/xgboost/defaults.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
from __future__ import absolute_import
1515

1616
XGBOOST_NAME = "xgboost"
17-
XGBOOST_LATEST_VERSION = "0.90-1"
18-
XGBOOST_SUPPORTED_VERSIONS = [XGBOOST_LATEST_VERSION]
17+
XGBOOST_VERSION_1 = "0.90-1"
18+
XGBOOST_LATEST_VERSION = "0.90-2"
19+
XGBOOST_SUPPORTED_VERSIONS = [XGBOOST_VERSION_1, XGBOOST_LATEST_VERSION]

tests/integ/test_debugger.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,6 @@
345345
# rules = [
346346
# Rule.sagemaker(rule_configs.vanishing_gradient()),
347347
# Rule.sagemaker(rule_configs.all_zero()),
348-
# Rule.sagemaker(rule_configs.check_input_images()),
349348
# Rule.sagemaker(rule_configs.similar_across_runs()),
350349
# Rule.sagemaker(rule_configs.weight_update_ratio()),
351350
# Rule.sagemaker(rule_configs.exploding_tensor()),

tests/unit/test_amazon_estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,9 @@ def test_get_xgboost_image_uri():
423423
updated_xgb_image_uri
424424
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3"
425425
)
426+
427+
updated_xgb_image_uri_v2 = get_image_uri(REGION, "xgboost", "0.90-2")
428+
assert (
429+
updated_xgb_image_uri_v2
430+
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
431+
)

tests/unit/test_xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_create_model(sagemaker_session):
175175
entry_point=SCRIPT_PATH,
176176
framework_version=XGBOOST_LATEST_VERSION,
177177
)
178-
default_image_uri = _get_full_cpu_image_uri("0.90-1")
178+
default_image_uri = _get_full_cpu_image_uri(XGBOOST_LATEST_VERSION)
179179
model_values = xgboost_model.prepare_container_def(CPU)
180180
assert model_values["Image"] == default_image_uri
181181

0 commit comments

Comments
 (0)