Skip to content

Commit eee0a8e

Browse files
committed
add warning when creating a constructor
1 parent 385d40a commit eee0a8e

File tree

5 files changed

+49
-13
lines changed

5 files changed

+49
-13
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
1719
from sagemaker.chainer.defaults import CHAINER_VERSION
1820
from sagemaker.chainer.model import ChainerModel
1921
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2022

23+
logging.basicConfig()
24+
logger = logging.getLogger('sagemaker')
25+
2126

2227
class Chainer(Framework):
2328
"""Handle end-to-end training and deployment of custom Chainer code."""
@@ -32,7 +37,7 @@ class Chainer(Framework):
3237

3338
def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None,
3439
additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3',
35-
framework_version=CHAINER_VERSION, image_name=None, **kwargs):
40+
framework_version=None, image_name=None, **kwargs):
3641
"""
3742
This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker
3843
Training Job. The managed Chainer environment is an Amazon-built Docker container that executes functions
@@ -79,12 +84,15 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_
7984
super(Chainer, self).__init__(entry_point, source_dir, hyperparameters,
8085
image_name=image_name, **kwargs)
8186
self.py_version = py_version
82-
self.framework_version = framework_version
8387
self.use_mpi = use_mpi
8488
self.num_processes = num_processes
8589
self.process_slots_per_host = process_slots_per_host
8690
self.additional_mpi_options = additional_mpi_options
8791

92+
if framework_version is None:
93+
logger.warning(empty_framework_version_warning(CHAINER_VERSION))
94+
self.framework_version = framework_version or CHAINER_VERSION
95+
8896
def hyperparameters(self):
8997
"""Return hyperparameters used by your custom Chainer code during training."""
9098
hyperparameters = super(Chainer, self).hyperparameters()

src/sagemaker/fw_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@
2323

2424
"""This module contains utility functions shared across ``Framework`` components."""
2525

26-
2726
UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
2827
"""sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name.
2928
3029
This is for the source code used for the entry point with an ``Estimator``. It can be
3130
instantiated with positional or keyword arguments.
3231
"""
3332

33+
EMPTY_FRAMEWORK_VERSION_WARNING = 'In an upcoming version of the SageMaker Python SDK, ' \
34+
'framework_version will be required to create an estimator. ' \
35+
'Please add framework_version={} to your constructor to avoid ' \
36+
'an error in the future.'
37+
3438

3539
def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638',
3640
optimized_families=[]):
@@ -223,3 +227,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
223227
str: the key prefix to be used in uploading code
224228
"""
225229
return '/'.join(filter(None, [code_location_key_prefix, model_name or name_from_image(image)]))
230+
231+
232+
def empty_framework_version_warning(default_version):
233+
return EMPTY_FRAMEWORK_VERSION_WARNING.format(default_version)

src/sagemaker/mxnet/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,25 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
from sagemaker.estimator import Framework
16-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
1719
from sagemaker.mxnet.defaults import MXNET_VERSION
1820
from sagemaker.mxnet.model import MXNetModel
1921
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2022

23+
logging.basicConfig()
24+
logger = logging.getLogger('sagemaker')
25+
2126

2227
class MXNet(Framework):
2328
"""Handle end-to-end training and deployment of custom MXNet code."""
2429

2530
__framework_name__ = "mxnet"
2631

2732
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
28-
framework_version=MXNET_VERSION, image_name=None, **kwargs):
33+
framework_version=None, image_name=None, **kwargs):
2934
"""
3035
This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker
3136
Training Job. The managed MXNet environment is an Amazon-built Docker container that executes functions
@@ -64,7 +69,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6469
super(MXNet, self).__init__(entry_point, source_dir, hyperparameters,
6570
image_name=image_name, **kwargs)
6671
self.py_version = py_version
67-
self.framework_version = framework_version
72+
73+
if framework_version is None:
74+
logger.warning(empty_framework_version_warning(MXNET_VERSION))
75+
self.framework_version = framework_version or MXNET_VERSION
6876

6977
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
7078
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``.

src/sagemaker/pytorch/estimator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,26 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
15+
import logging
16+
1417
from sagemaker.estimator import Framework
15-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
18+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
1619
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
1720
from sagemaker.pytorch.model import PyTorchModel
1821
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
1922

23+
logging.basicConfig()
24+
logger = logging.getLogger('sagemaker')
25+
2026

2127
class PyTorch(Framework):
2228
"""Handle end-to-end training and deployment of custom PyTorch code."""
2329

2430
__framework_name__ = "pytorch"
2531

2632
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
27-
framework_version=PYTORCH_VERSION, image_name=None, **kwargs):
33+
framework_version=None, image_name=None, **kwargs):
2834
"""
2935
This ``Estimator`` executes an PyTorch script in a managed PyTorch execution environment, within a SageMaker
3036
Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions
@@ -62,7 +68,10 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
6268
"""
6369
super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs)
6470
self.py_version = py_version
65-
self.framework_version = framework_version
71+
72+
if framework_version is None:
73+
logger.warning(empty_framework_version_warning(PYTORCH_VERSION))
74+
self.framework_version = framework_version or PYTORCH_VERSION
6675

6776
def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
6877
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.

src/sagemaker/tensorflow/estimator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import time
2323

2424
from sagemaker.estimator import Framework
25-
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag
25+
from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning
2626
from sagemaker.utils import get_config_value
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

@@ -159,7 +159,7 @@ class TensorFlow(Framework):
159159
__framework_name__ = 'tensorflow'
160160

161161
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
162-
framework_version=TF_VERSION, requirements_file='', image_name=None, **kwargs):
162+
framework_version=None, requirements_file='', image_name=None, **kwargs):
163163
"""Initialize an ``TensorFlow`` estimator.
164164
Args:
165165
training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -184,13 +184,16 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
184184
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
185185
self.checkpoint_path = checkpoint_path
186186
self.py_version = py_version
187-
self.framework_version = framework_version
188187
self.training_steps = training_steps
189188
self.evaluation_steps = evaluation_steps
190189

191190
self._validate_requirements_file(requirements_file)
192191
self.requirements_file = requirements_file
193192

193+
if framework_version is None:
194+
LOGGER.warning(empty_framework_version_warning(TF_VERSION))
195+
self.framework_version = framework_version or TF_VERSION
196+
194197
def _validate_requirements_file(self, requirements_file):
195198
if not requirements_file:
196199
return

0 commit comments

Comments
 (0)