Skip to content

Commit 926690e

Browse files
authored
fix: add default framework version warning message in Model classes (#1218)
* fix: add default framework version warning message in Model classes * fix black-format errors * add LATEST_VERSION in frameowrk defaults.py
1 parent 0e7c211 commit 926690e

File tree

16 files changed

+115
-28
lines changed

16 files changed

+115
-28
lines changed

src/sagemaker/chainer/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@
1717
"""Default Chainer version for when the framework version is not specified.
1818
This is no longer updated so as to not break existing workflows.
1919
"""
20+
21+
LATEST_VERSION = "5.0.0"
22+
"""The latest version of Chainer included in the SageMaker pre-built Docker images."""

src/sagemaker/chainer/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
empty_framework_version_warning,
2323
python_deprecation_warning,
2424
)
25-
from sagemaker.chainer.defaults import CHAINER_VERSION
25+
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION
2626
from sagemaker.chainer.model import ChainerModel
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

@@ -40,8 +40,7 @@ class Chainer(Framework):
4040
_process_slots_per_host = "sagemaker_process_slots_per_host"
4141
_additional_mpi_options = "sagemaker_additional_mpi_options"
4242

43-
LATEST_VERSION = "5.0.0"
44-
"""The latest version of Chainer included in the SageMaker pre-built Docker images."""
43+
LATEST_VERSION = LATEST_VERSION
4544

4645
def __init__(
4746
self,

src/sagemaker/chainer/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@
1616
import logging
1717

1818
import sagemaker
19-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
19+
from sagemaker.fw_utils import (
20+
create_image_uri,
21+
model_code_key_prefix,
22+
python_deprecation_warning,
23+
empty_framework_version_warning,
24+
)
2025
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
21-
from sagemaker.chainer.defaults import CHAINER_VERSION
26+
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION
2227
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2328

2429
logger = logging.getLogger("sagemaker")
@@ -61,7 +66,7 @@ def __init__(
6166
entry_point,
6267
image=None,
6368
py_version="py3",
64-
framework_version=CHAINER_VERSION,
69+
framework_version=None,
6570
predictor_cls=ChainerPredictor,
6671
model_server_workers=None,
6772
**kwargs
@@ -107,9 +112,11 @@ def __init__(
107112
)
108113
if py_version == "py2":
109114
logger.warning(python_deprecation_warning(self.__framework_name__))
115+
if framework_version is None:
116+
logger.warning(empty_framework_version_warning(CHAINER_VERSION, LATEST_VERSION))
110117

111118
self.py_version = py_version
112-
self.framework_version = framework_version
119+
self.framework_version = framework_version or CHAINER_VERSION
113120
self.model_server_workers = model_server_workers
114121

115122
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/mxnet/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@
1717
"""Default MXNet version for when the framework version is not specified.
1818
This is no longer updated so as to not break existing workflows.
1919
"""
20+
21+
LATEST_VERSION = "1.6.0"
22+
"""The latest version of MXNet included in the SageMaker pre-built Docker images."""

src/sagemaker/mxnet/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
python_deprecation_warning,
2424
is_version_equal_or_higher,
2525
)
26-
from sagemaker.mxnet.defaults import MXNET_VERSION
26+
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION
2727
from sagemaker.mxnet.model import MXNetModel
2828
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2929

@@ -36,8 +36,7 @@ class MXNet(Framework):
3636
__framework_name__ = "mxnet"
3737
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]
3838

39-
LATEST_VERSION = "1.6.0"
40-
"""The latest version of MXNet included in the SageMaker pre-built Docker images."""
39+
LATEST_VERSION = LATEST_VERSION
4140

4241
def __init__(
4342
self,

src/sagemaker/mxnet/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@
1818
from pkg_resources import parse_version
1919

2020
import sagemaker
21-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
21+
from sagemaker.fw_utils import (
22+
create_image_uri,
23+
model_code_key_prefix,
24+
python_deprecation_warning,
25+
empty_framework_version_warning,
26+
)
2227
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
23-
from sagemaker.mxnet.defaults import MXNET_VERSION
28+
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION
2429
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
2530

2631
logger = logging.getLogger("sagemaker")
@@ -62,7 +67,7 @@ def __init__(
6267
entry_point,
6368
image=None,
6469
py_version="py2",
65-
framework_version=MXNET_VERSION,
70+
framework_version=None,
6671
predictor_cls=MXNetPredictor,
6772
model_server_workers=None,
6873
**kwargs
@@ -109,9 +114,11 @@ def __init__(
109114

110115
if py_version == "py2":
111116
logger.warning(python_deprecation_warning(self.__framework_name__))
117+
if framework_version is None:
118+
logger.warning(empty_framework_version_warning(MXNET_VERSION, LATEST_VERSION))
112119

113120
self.py_version = py_version
114-
self.framework_version = framework_version
121+
self.framework_version = framework_version or MXNET_VERSION
115122
self.model_server_workers = model_server_workers
116123

117124
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/pytorch/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@
1919
break existing workflows.
2020
"""
2121

22+
LATEST_VERSION = "1.3.1"
23+
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""
24+
2225
PYTHON_VERSION = "py3"

src/sagemaker/pytorch/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
python_deprecation_warning,
2424
is_version_equal_or_higher,
2525
)
26-
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
26+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION
2727
from sagemaker.pytorch.model import PyTorchModel
2828
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2929

@@ -35,8 +35,7 @@ class PyTorch(Framework):
3535

3636
__framework_name__ = "pytorch"
3737

38-
LATEST_VERSION = "1.3.1"
39-
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""
38+
LATEST_VERSION = LATEST_VERSION
4039

4140
def __init__(
4241
self,

src/sagemaker/pytorch/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@
1717
import pkg_resources
1818

1919
import sagemaker
20-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
20+
from sagemaker.fw_utils import (
21+
create_image_uri,
22+
model_code_key_prefix,
23+
python_deprecation_warning,
24+
empty_framework_version_warning,
25+
)
2126
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
27+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION
2328
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2429

2530
logger = logging.getLogger("sagemaker")
@@ -63,7 +68,7 @@ def __init__(
6368
entry_point,
6469
image=None,
6570
py_version=PYTHON_VERSION,
66-
framework_version=PYTORCH_VERSION,
71+
framework_version=None,
6772
predictor_cls=PyTorchPredictor,
6873
model_server_workers=None,
6974
**kwargs
@@ -110,9 +115,11 @@ def __init__(
110115

111116
if py_version == "py2":
112117
logger.warning(python_deprecation_warning(self.__framework_name__))
118+
if framework_version is None:
119+
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION))
113120

114121
self.py_version = py_version
115-
self.framework_version = framework_version
122+
self.framework_version = framework_version or PYTORCH_VERSION
116123
self.model_server_workers = model_server_workers
117124

118125
def prepare_container_def(self, instance_type, accelerator_type=None):

src/sagemaker/tensorflow/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@
1717
"""Default TF version for when the framework version is not specified.
1818
This is no longer updated so as to not break existing workflows.
1919
"""
20+
21+
LATEST_VERSION = "2.0.0"
22+
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""

src/sagemaker/tensorflow/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sagemaker.debugger import DebuggerHookConfig
2626
from sagemaker.estimator import Framework
2727
import sagemaker.fw_utils as fw
28-
from sagemaker.tensorflow.defaults import TF_VERSION
28+
from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION
2929
from sagemaker.tensorflow.model import TensorFlowModel
3030
from sagemaker.tensorflow.serving import Model
3131
from sagemaker.transformer import Transformer
@@ -197,8 +197,7 @@ class TensorFlow(Framework):
197197

198198
__framework_name__ = "tensorflow"
199199

200-
LATEST_VERSION = "2.0.0"
201-
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
200+
LATEST_VERSION = LATEST_VERSION
202201

203202
_LATEST_1X_VERSION = "1.15.0"
204203

src/sagemaker/tensorflow/model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@
1616
import logging
1717

1818
import sagemaker
19-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
19+
from sagemaker.fw_utils import (
20+
create_image_uri,
21+
model_code_key_prefix,
22+
python_deprecation_warning,
23+
empty_framework_version_warning,
24+
)
2025
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2126
from sagemaker.predictor import RealTimePredictor
22-
from sagemaker.tensorflow.defaults import TF_VERSION
27+
from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION
2328
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
2429

2530
logger = logging.getLogger("sagemaker")
@@ -60,7 +65,7 @@ def __init__(
6065
entry_point,
6166
image=None,
6267
py_version="py2",
63-
framework_version=TF_VERSION,
68+
framework_version=None,
6469
predictor_cls=TensorFlowPredictor,
6570
model_server_workers=None,
6671
**kwargs
@@ -107,9 +112,11 @@ def __init__(
107112

108113
if py_version == "py2":
109114
logger.warning(python_deprecation_warning(self.__framework_name__))
115+
if framework_version is None:
116+
logger.warning(empty_framework_version_warning(TF_VERSION, LATEST_VERSION))
110117

111118
self.py_version = py_version
112-
self.framework_version = framework_version
119+
self.framework_version = framework_version or TF_VERSION
113120
self.model_server_workers = model_server_workers
114121

115122
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/unit/test_chainer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,16 @@ def test_empty_framework_version(warning, sagemaker_session):
601601

602602
assert estimator.framework_version == defaults.CHAINER_VERSION
603603
warning.assert_called_with(defaults.CHAINER_VERSION, Chainer.LATEST_VERSION)
604+
605+
606+
@patch("sagemaker.chainer.model.empty_framework_version_warning")
607+
def test_model_empty_framework_version(warning, sagemaker_session):
608+
model = ChainerModel(
609+
MODEL_DATA,
610+
role=ROLE,
611+
entry_point=SCRIPT_PATH,
612+
sagemaker_session=sagemaker_session,
613+
framework_version=None,
614+
)
615+
assert model.framework_version == defaults.CHAINER_VERSION
616+
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)

tests/unit/test_mxnet.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,19 @@ def test_empty_framework_version(warning, sagemaker_session):
722722
warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION)
723723

724724

725+
@patch("sagemaker.mxnet.model.empty_framework_version_warning")
726+
def test_model_empty_framework_version(warning, sagemaker_session):
727+
model = MXNetModel(
728+
MODEL_DATA,
729+
role=ROLE,
730+
entry_point=SCRIPT_PATH,
731+
sagemaker_session=sagemaker_session,
732+
framework_version=None,
733+
)
734+
assert model.framework_version == defaults.MXNET_VERSION
735+
warning.assert_called_with(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
736+
737+
725738
def test_create_model_with_custom_hosting_image(sagemaker_session):
726739
container_log_level = '"logging.INFO"'
727740
source_dir = "s3://mybucket/source"

tests/unit/test_pytorch.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,20 @@ def test_empty_framework_version(warning, sagemaker_session):
532532
warning.assert_called_with(defaults.PYTORCH_VERSION, estimator.LATEST_VERSION)
533533

534534

535+
@patch("sagemaker.pytorch.model.empty_framework_version_warning")
536+
def test_model_empty_framework_version(warning, sagemaker_session):
537+
model = PyTorchModel(
538+
MODEL_DATA,
539+
role=ROLE,
540+
entry_point=SCRIPT_PATH,
541+
sagemaker_session=sagemaker_session,
542+
framework_version=None,
543+
)
544+
545+
assert model.framework_version == defaults.PYTORCH_VERSION
546+
warning.assert_called_with(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
547+
548+
535549
def test_pt_enable_sm_metrics(sagemaker_session):
536550
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=True)
537551
assert pytorch.enable_sagemaker_metrics

tests/unit/test_tf_estimator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,17 @@ def test_empty_framework_version(warning, sagemaker_session):
938938
assert estimator.framework_version == defaults.TF_VERSION
939939
warning.assert_called_with(defaults.TF_VERSION, estimator.LATEST_VERSION)
940940

941+
model = TensorFlowModel(
942+
MODEL_DATA,
943+
role=ROLE,
944+
entry_point=SCRIPT_PATH,
945+
sagemaker_session=sagemaker_session,
946+
framework_version=None,
947+
)
948+
949+
assert model.framework_version == defaults.TF_VERSION
950+
warning.assert_called_with(defaults.TF_VERSION, defaults.LATEST_VERSION)
951+
941952

942953
def _deprecated_args_msg(args):
943954
return "{} are deprecated in script mode. Please do not set {}.".format(

0 commit comments

Comments
 (0)