Skip to content

Commit 78f1a69

Browse files
chuyang-dengDan
authored andcommitted
change: add py2 deprecation message for the deep learning framework images (#768)
1 parent 28c23bc commit 78f1a69

File tree

11 files changed

+83
-13
lines changed

11 files changed

+83
-13
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import logging
1616

1717
from sagemaker.estimator import Framework
18-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
19+
python_deprecation_warning
1920
from sagemaker.chainer.defaults import CHAINER_VERSION
2021
from sagemaker.chainer.model import ChainerModel
2122
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -90,6 +91,10 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
9091

9192
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
9293
image_name=image_name, **kwargs)
94+
95+
if py_version == 'py2':
96+
logger.warning(python_deprecation_warning(self.__framework_name__))
97+
9398
self.py_version = py_version
9499
self.use_mpi = use_mpi
95100
self.num_processes = num_processes

src/sagemaker/chainer/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
18+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
1719
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1820
from sagemaker.chainer.defaults import CHAINER_VERSION
1921
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2022

23+
logger = logging.getLogger('sagemaker')
24+
2125

2226
class ChainerPredictor(RealTimePredictor):
2327
"""A RealTimePredictor for inference against Chainer Endpoints.
@@ -66,6 +70,9 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
6670
"""
6771
super(ChainerModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
6872
**kwargs)
73+
if py_version == 'py2':
74+
logger.warning(python_deprecation_warning(self.__framework_name__))
75+
6976
self.py_version = py_version
7077
self.framework_version = framework_version
7178
self.model_server_workers = model_server_workers

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
LATER_FRAMEWORK_VERSION_WARNING = 'This is not the latest supported version. ' \
3737
'If you would like to use version {latest}, ' \
3838
'please add framework_version={latest} to your constructor.'
39+
PYTHON_2_DEPRECATION_WARNING = 'The Python 2 {framework} images will be soon deprecated and may not be ' \
40+
'supported for newer upcoming versions of the {framework} images.\n' \
41+
'Please set the argument \"py_version=\'py3\'\" to use the Python 3 {framework} image.'
42+
3943

4044
EMPTY_FRAMEWORK_VERSION_ERROR = 'framework_version is required for script mode estimator. ' \
4145
'Please add framework_version={} to your constructor to avoid this error.'
@@ -303,3 +307,7 @@ def empty_framework_version_warning(default_version, latest_version):
303307
if default_version != latest_version:
304308
msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version))
305309
return ' '.join(msgs)
310+
311+
312+
def python_deprecation_warning(framework):
313+
return PYTHON_2_DEPRECATION_WARNING.format(framework=framework)

src/sagemaker/mxnet/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import logging
1616

1717
from sagemaker.estimator import Framework
18-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
19+
python_deprecation_warning
1920
from sagemaker.mxnet.defaults import MXNET_VERSION
2021
from sagemaker.mxnet.model import MXNetModel
2122
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -79,6 +80,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
7980

8081
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
8182
image_name=image_name, **kwargs)
83+
84+
if py_version == 'py2':
85+
logger.warning(python_deprecation_warning(self.__framework_name__))
86+
8287
self.py_version = py_version
8388
self._configure_distribution(distributions)
8489

src/sagemaker/mxnet/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
18+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
1719
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1820
from sagemaker.mxnet.defaults import MXNET_VERSION
1921
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
2022

23+
logger = logging.getLogger('sagemaker')
24+
2125

2226
class MXNetPredictor(RealTimePredictor):
2327
"""A RealTimePredictor for inference against MXNet Endpoints.
@@ -66,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
6670
"""
6771
super(MXNetModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
6872
**kwargs)
73+
74+
if py_version == 'py2':
75+
logger.warning(python_deprecation_warning(self.__framework_name__))
76+
6977
self.py_version = py_version
7078
self.framework_version = framework_version
7179
self.model_server_workers = model_server_workers

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import logging
1616

1717
from sagemaker.estimator import Framework
18-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \
19+
python_deprecation_warning
1920
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
2021
from sagemaker.pytorch.model import PyTorchModel
2122
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -74,6 +75,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
7475
self.framework_version = framework_version or PYTORCH_VERSION
7576

7677
super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs)
78+
79+
if py_version == 'py2':
80+
logger.warning(python_deprecation_warning(self.__framework_name__))
81+
7782
self.py_version = py_version
7883

7984
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):

src/sagemaker/pytorch/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
import logging
16+
1417
import sagemaker
15-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
18+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
1619
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1720
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
1821
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
1922

23+
logger = logging.getLogger('sagemaker')
24+
2025

2126
class PyTorchPredictor(RealTimePredictor):
2227
"""A RealTimePredictor for inference against PyTorch Endpoints.
@@ -65,6 +70,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_
6570
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
6671
"""
6772
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
73+
74+
if py_version == 'py2':
75+
logger.warning(python_deprecation_warning(self.__framework_name__))
76+
6877
self.py_version = py_version
6978
self.framework_version = framework_version
7079
self.model_server_workers = model_server_workers

src/sagemaker/sklearn/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from sagemaker.estimator import Framework
1818
from sagemaker.fw_registry import default_framework_uri
19-
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning
19+
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning, python_deprecation_warning
2020
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
2121
from sagemaker.sklearn.model import SKLearnModel
2222
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
@@ -79,6 +79,9 @@ def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=No
7979
super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name,
8080
**dict(kwargs, train_instance_count=1))
8181

82+
if py_version == 'py2':
83+
logger.warning(python_deprecation_warning(self.__framework_name__))
84+
8285
self.py_version = py_version
8386

8487
if framework_version is None:

src/sagemaker/sklearn/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
import sagemaker
16-
from sagemaker.fw_utils import model_code_key_prefix
18+
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
1719
from sagemaker.fw_registry import default_framework_uri
1820
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1921
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
2022
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
2123

24+
logger = logging.getLogger('sagemaker')
25+
2226

2327
class SKLearnPredictor(RealTimePredictor):
2428
"""A RealTimePredictor for inference against Scikit-learn Endpoints.
@@ -68,6 +72,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
6872
"""
6973
super(SKLearnModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
7074
**kwargs)
75+
76+
if py_version == 'py2':
77+
logger.warning(python_deprecation_warning(self.__framework_name__))
78+
7179
self.py_version = py_version
7280
self.framework_version = framework_version
7381
self.model_server_workers = model_server_workers

src/sagemaker/tensorflow/estimator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sagemaker.utils import get_config_value
3030
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
3131

32-
LOGGER = logging.getLogger('sagemaker')
32+
logger = logging.getLogger('sagemaker')
3333

3434

3535
_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path')
@@ -154,7 +154,7 @@ def run(self):
154154
"""Run TensorBoard process."""
155155
port, tensorboard_process = self.create_tensorboard_process()
156156

157-
LOGGER.info('TensorBoard 0.1.7 at http://localhost:{}'.format(port))
157+
logger.info('TensorBoard 0.1.7 at http://localhost:{}'.format(port))
158158
while not self.estimator.checkpoint_path:
159159
self.event.wait(1)
160160
with self._temporary_directory() as aws_sync_dir:
@@ -231,11 +231,15 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
231231
**kwargs: Additional kwargs passed to the Framework constructor.
232232
"""
233233
if framework_version is None:
234-
LOGGER.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
234+
logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
235235
self.framework_version = framework_version or TF_VERSION
236236

237237
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
238238
self.checkpoint_path = checkpoint_path
239+
240+
if py_version == 'py2':
241+
logger.warning('tensorflow py2 container will be deprecated soon.')
242+
239243
self.py_version = py_version
240244
self.training_steps = training_steps
241245
self.evaluation_steps = evaluation_steps
@@ -320,7 +324,7 @@ def fit_super():
320324
raise ValueError("Tensorboard is not supported with async fit")
321325

322326
if self._script_mode_enabled() and run_tensorboard_locally:
323-
LOGGER.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
327+
logger.warning(_SCRIPT_MODE_TENSORBOARD_WARNING.format(self.model_dir))
324328
fit_super()
325329
elif run_tensorboard_locally:
326330
tensorboard = Tensorboard(self)

src/sagemaker/tensorflow/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
18+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
1719
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1820
from sagemaker.predictor import RealTimePredictor
1921
from sagemaker.tensorflow.defaults import TF_VERSION
2022
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
2123

24+
logger = logging.getLogger('sagemaker')
25+
2226

2327
class TensorFlowPredictor(RealTimePredictor):
2428
"""A ``RealTimePredictor`` for inference against TensorFlow endpoint.
@@ -67,6 +71,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
6771
"""
6872
super(TensorFlowModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
6973
**kwargs)
74+
75+
if py_version == 'py2':
76+
logger.warning(python_deprecation_warning(self.__framework_name__))
77+
7078
self.py_version = py_version
7179
self.framework_version = framework_version
7280
self.model_server_workers = model_server_workers

0 commit comments

Comments
 (0)