Skip to content

Commit b214078

Browse files
Edward J Kimchuyang-deng
authored andcommitted
feature: add support for SKLearn 0.23 (#1561)
Co-authored-by: Chuyang <[email protected]>
1 parent f586b3f commit b214078

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

src/sagemaker/fw_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,10 +587,18 @@ def empty_framework_version_warning(default_version, latest_version):
587587
"""
588588
msgs = [EMPTY_FRAMEWORK_VERSION_WARNING.format(default_version)]
589589
if default_version != latest_version:
590-
msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version))
590+
msgs.append(later_framework_version_warning(latest_version))
591591
return " ".join(msgs)
592592

593593

594+
def later_framework_version_warning(latest_version):
595+
"""
596+
Args:
597+
latest_version:
598+
"""
599+
return LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version)
600+
601+
594602
def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):
595603
"""Warn the user that training will not fully leverage all the GPU
596604
cores if parameter server is enabled and a multi-GPU instance is selected.

src/sagemaker/sklearn/defaults.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515

1616
SKLEARN_NAME = "scikit-learn"
1717

18+
# Default SKLearn version for when the framework version is not specified.
19+
# This is no longer updated so as to not break existing workflows.
1820
SKLEARN_VERSION = "0.20.0"
21+
SKLEARN_LATEST_VERSION = "0.23-1"
22+
SKLEARN_SUPPORTED_VERSIONS = [SKLEARN_VERSION, SKLEARN_LATEST_VERSION]
23+
1924

2025
LATEST_PY2_VERSION = "0.20.0"

src/sagemaker/sklearn/estimator.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from sagemaker.fw_registry import default_framework_uri
2020
from sagemaker.fw_utils import (
2121
framework_name_from_image,
22-
empty_framework_version_warning,
22+
get_unsupported_framework_version_error,
23+
later_framework_version_warning,
2324
python_deprecation_warning,
2425
)
2526
from sagemaker.sklearn import defaults
@@ -126,11 +127,17 @@ def __init__(
126127

127128
self.py_version = py_version
128129

129-
if framework_version is None:
130-
logger.warning(
131-
empty_framework_version_warning(defaults.SKLEARN_VERSION, defaults.SKLEARN_VERSION)
130+
if framework_version in defaults.SKLEARN_SUPPORTED_VERSIONS:
131+
self.framework_version = framework_version
132+
else:
133+
raise ValueError(
134+
get_unsupported_framework_version_error(
135+
self.__framework_name__, framework_version, defaults.SKLEARN_SUPPORTED_VERSIONS
136+
)
132137
)
133-
self.framework_version = framework_version or defaults.SKLEARN_VERSION
138+
139+
if framework_version != defaults.SKLEARN_LATEST_VERSION:
140+
logger.warning(later_framework_version_warning(defaults.SKLEARN_LATEST_VERSION))
134141

135142
if image_name is None:
136143
image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version)

tests/unit/test_sklearn.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,35 @@ def test_estimator_py2_warning(warning, sagemaker_session):
574574
warning.assert_called_with(estimator.__framework_name__, defaults.LATEST_PY2_VERSION)
575575

576576

577+
@patch("sagemaker.sklearn.estimator.later_framework_version_warning")
578+
def test_estimator_later_framework_version_warning(warning, sagemaker_session):
579+
estimator = SKLearn(
580+
entry_point=SCRIPT_PATH,
581+
role=ROLE,
582+
sagemaker_session=sagemaker_session,
583+
train_instance_count=INSTANCE_COUNT,
584+
train_instance_type=INSTANCE_TYPE,
585+
)
586+
587+
assert estimator.framework_version == defaults.SKLEARN_VERSION
588+
warning.assert_called_with(defaults.SKLEARN_LATEST_VERSION)
589+
590+
591+
@patch("sagemaker.sklearn.estimator.get_unsupported_framework_version_error")
592+
def test_estimator_throws_error_for_unsupported_version(error, sagemaker_session):
593+
with pytest.raises(ValueError):
594+
estimator = SKLearn(
595+
entry_point=SCRIPT_PATH,
596+
role=ROLE,
597+
sagemaker_session=sagemaker_session,
598+
train_instance_count=INSTANCE_COUNT,
599+
train_instance_type=INSTANCE_TYPE,
600+
framework_version="foo",
601+
)
602+
assert estimator.framework_version not in defaults.SKLEARN_SUPPORTED_VERSIONS
603+
error.assert_called_with(defaults.SKLEARN_NAME, "foo", defaults.SKLEARN_SUPPORT_VERSIONS)
604+
605+
577606
@patch("sagemaker.sklearn.model.python_deprecation_warning")
578607
def test_model_py2_warning(warning, sagemaker_session):
579608
source_dir = "s3://mybucket/source"

0 commit comments

Comments
 (0)