Skip to content

Commit 74ea308

Browse files
Edward J Kim, chuyang-deng, icywang86rui
authored
feature: add xgboost framework version 1.2-1 (#1899)
* feature: add xgboost framework version 1.2-1
* Add copyright to xgboost utils file
* Add absolute_import to xgboost utils file
* Remove xgboost_latest_py_version fixture from airflow config test

Co-authored-by: Chuyang <[email protected]>
Co-authored-by: icywang86rui <[email protected]>
1 parent b3f8abc commit 74ea308

File tree

11 files changed

+204
-28
lines changed

11 files changed

+204
-28
lines changed

src/sagemaker/image_uri_config/xgboost.json

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,35 @@
125125
"us-west-2": "246618743249"
126126
},
127127
"repository": "sagemaker-xgboost"
128+
},
129+
"1.2-1": {
130+
"registries": {
131+
"af-south-1": "510948584623",
132+
"ap-east-1": "651117190479",
133+
"ap-northeast-1": "354813040037",
134+
"ap-northeast-2": "366743142698",
135+
"ap-south-1": "720646828776",
136+
"ap-southeast-1": "121021644041",
137+
"ap-southeast-2": "783357654285",
138+
"ca-central-1": "341280168497",
139+
"cn-north-1": "450853457545",
140+
"cn-northwest-1": "451049120500",
141+
"eu-central-1": "492215442770",
142+
"eu-north-1": "662702820516",
143+
"eu-west-1": "141502667606",
144+
"eu-west-2": "764974769150",
145+
"eu-west-3": "659782779980",
146+
"eu-south-1": "978288397137",
147+
"me-south-1": "801668240914",
148+
"sa-east-1": "737474898029",
149+
"us-east-1": "683313688378",
150+
"us-east-2": "257758044811",
151+
"us-gov-west-1": "414596584902",
152+
"us-iso-east-1": "833128469047",
153+
"us-west-1": "746614075791",
154+
"us-west-2": "246618743249"
155+
},
156+
"repository": "sagemaker-xgboost"
128157
}
129158
}
130159
}

src/sagemaker/image_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,4 @@ def _validate_arg(arg, available_options, arg_name):
252252

253253
def _format_tag(tag_prefix, processor, py_version, container_version):
254254
"""Creates a tag for the image URI."""
255-
return "-".join([x for x in (tag_prefix, processor, py_version, container_version) if x])
255+
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)

src/sagemaker/xgboost/defaults.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@
1414
from __future__ import absolute_import
1515

1616
XGBOOST_NAME = "xgboost"
17+
XGBOOST_UNSUPPORTED_VERSIONS = {
18+
"1.1": (
19+
"XGBoost 1.1 is not supported on SageMaker because XGBoost 1.1 has broken capability to "
20+
"run prediction when the test input has fewer features than the training data in LIBSVM "
21+
"inputs. This capability has been restored in XGBoost 1.2 "
22+
"(https://github.com/dmlc/xgboost/pull/5955). Consider using SageMaker XGBoost 1.2-1."
23+
),
24+
}

src/sagemaker/xgboost/estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2727
from sagemaker.xgboost import defaults
2828
from sagemaker.xgboost.model import XGBoostModel
29+
from sagemaker.xgboost.utils import validate_py_version, validate_framework_version
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -101,6 +102,9 @@ def __init__(
101102
self.py_version = py_version
102103
self.framework_version = framework_version
103104

105+
validate_py_version(py_version)
106+
validate_framework_version(framework_version)
107+
104108
if image_uri is None:
105109
self.image_uri = image_uris.retrieve(
106110
self._framework_name,

src/sagemaker/xgboost/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.serializers import LibSVMSerializer
2525
from sagemaker.xgboost.defaults import XGBOOST_NAME
26+
from sagemaker.xgboost.utils import validate_py_version, validate_framework_version
2627

2728
logger = logging.getLogger("sagemaker")
2829

@@ -105,13 +106,15 @@ def __init__(
105106
self.framework_version = framework_version
106107
self.model_server_workers = model_server_workers
107108

109+
validate_py_version(py_version)
110+
validate_framework_version(framework_version)
111+
108112
def prepare_container_def(self, instance_type=None, accelerator_type=None):
109113
"""Return a container definition with framework configuration
110114
set in model environment variables.
111115
112116
Args:
113117
instance_type (str): The EC2 instance type to deploy this Model to.
114-
This parameter is unused because XGBoost supports only CPU.
115118
accelerator_type (str): The Elastic Inference accelerator type to deploy to the
116119
instance for loading and making inferences to the model. This parameter is
117120
unused because accelerator types are not supported by XGBoostModel.
@@ -148,7 +151,5 @@ def serving_image_uri(self, region_name, instance_type):
148151
self._framework_name,
149152
region_name,
150153
version=self.framework_version,
151-
py_version=self.py_version,
152154
instance_type=instance_type,
153-
image_scope="inference",
154155
)

src/sagemaker/xgboost/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.xgboost import defaults
17+
18+
19+
def validate_py_version(py_version):
20+
"""Placeholder docstring"""
21+
if py_version != "py3":
22+
raise ValueError("Unsupported Python version: {}.".format(py_version))
23+
24+
25+
def validate_framework_version(framework_version):
26+
"""Placeholder docstring"""
27+
28+
xgboost_version = framework_version.split("-")[0]
29+
if xgboost_version in defaults.XGBOOST_UNSUPPORTED_VERSIONS:
30+
msg = defaults.XGBOOST_UNSUPPORTED_VERSIONS[xgboost_version]
31+
raise ValueError(msg)

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def xgboost_framework_version(xgboost_version):
184184
return xgboost_version
185185

186186

187+
@pytest.fixture(scope="module")
def xgboost_gpu_framework_version(xgboost_version):
    """Return ``xgboost_version``, skipping the test for versions below 1.2.

    The test is skipped when ``xgboost_version`` is ``"1"`` or ``"latest"``
    (reported as algorithm versions) or when it compares lower than ``1.2``
    (reported as cpu-only versions).
    """
    # Checked first so these identifiers never reach the Version comparison.
    if xgboost_version in ("1", "latest"):
        pytest.skip("Skipping XGBoost algorithm version.")
    if Version(xgboost_version) < Version("1.2"):
        pytest.skip("Skipping XGBoost cpu-only version.")
    return xgboost_version
194+
195+
187196
@pytest.fixture(scope="module", params=["py2", "py3"])
188197
def tensorflow_training_py_version(tensorflow_training_version, request):
189198
return _tf_py_version(tensorflow_training_version, request)

tests/integ/test_airflow_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,13 +552,13 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
552552

553553
@pytest.mark.canary_quick
554554
def test_xgboost_airflow_config_uploads_data_source_to_s3(
555-
sagemaker_session, cpu_instance_type, xgboost_latest_version, xgboost_latest_py_version
555+
sagemaker_session, cpu_instance_type, xgboost_latest_version
556556
):
557557
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
558558
xgboost = XGBoost(
559559
entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
560560
framework_version=xgboost_latest_version,
561-
py_version=xgboost_latest_py_version,
561+
py_version="py3",
562562
role=ROLE,
563563
sagemaker_session=sagemaker_session,
564564
instance_type=cpu_instance_type,

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525

2626
def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", region=REGION):
2727
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
28-
tag = "{}-{}".format(fw_version, processor)
29-
if py_version:
30-
tag = "-".join((tag, py_version))
28+
tag = "-".join(x for x in (fw_version, processor, py_version) if x)
3129

3230
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
3331

tests/unit/sagemaker/image_uris/test_xgboost.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
"us-west-2": "433757028032",
4545
}
4646
ALGO_VERSIONS = ("1", "latest")
47-
XGBOOST_FRAMEWORK_VERSIONS = ("0.90-2", "0.90-1", "1.0-1")
47+
XGBOOST_FRAMEWORK_CPU_ONLY_VERSIONS = ("0.90-2", "0.90-1", "1.0-1")
48+
XGBOOST_FRAMEWORK_CPU_GPU_VERSIONS = ("1.2-1",)
4849

4950
FRAMEWORK_REGISTRIES = {
5051
"af-south-1": "510948584623",
@@ -74,22 +75,42 @@
7475
}
7576

7677

77-
@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_VERSIONS)
78+
@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_GPU_VERSIONS)
7879
def test_xgboost_framework(xgboost_framework_version):
7980
for region in regions.regions():
8081
uri = image_uris.retrieve(
8182
framework="xgboost",
8283
region=region,
8384
version=xgboost_framework_version,
84-
py_version="py3",
8585
)
8686

8787
expected = expected_uris.framework_uri(
8888
"sagemaker-xgboost",
8989
xgboost_framework_version,
9090
FRAMEWORK_REGISTRIES[region],
91-
py_version="py3",
91+
py_version=None,
92+
processor=None,
93+
region=region,
94+
)
95+
assert expected == uri
96+
97+
98+
@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_ONLY_VERSIONS)
99+
def test_xgboost_framework_cpu_only(xgboost_framework_version):
100+
for region in regions.regions():
101+
uri = image_uris.retrieve(
102+
framework="xgboost",
92103
region=region,
104+
version=xgboost_framework_version,
105+
)
106+
107+
expected = expected_uris.framework_uri(
108+
"sagemaker-xgboost",
109+
xgboost_framework_version,
110+
FRAMEWORK_REGISTRIES[region],
111+
region=region,
112+
py_version="py3",
113+
processor="cpu",
93114
)
94115
assert expected == uri
95116

0 commit comments

Comments
 (0)