Skip to content

Commit f2ccf64

Browse files
committed
breaking: require framework_version, py_version for mxnet
* framework_version, py_version required for framework MXNet * framework_version, py_version required for framework MXNetModel * re-order of non-default args, convention to follow entry_point * testing updates * doc updates * ignore coverage results for py27 env due to v2 migration scripts
1 parent 97cd594 commit f2ccf64

File tree

10 files changed

+231
-173
lines changed

10 files changed

+231
-173
lines changed

doc/frameworks/mxnet/using_mxnet.rst

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ The following code sample shows how you train a custom MXNet script "train.py".
345345
mxnet_estimator = MXNet('train.py',
346346
train_instance_type='ml.p2.xlarge',
347347
train_instance_count=1,
348-
framework_version='1.3.0',
348+
framework_version='1.6.0',
349+
py_version='py3',
349350
hyperparameters={'batch-size': 100,
350351
'epochs': 10,
351352
'learning-rate': 0.1})
@@ -392,10 +393,10 @@ If you use the ``MXNet`` estimator to train the model, you can call ``deploy`` t
392393
393394
# Train my estimator
394395
mxnet_estimator = MXNet('train.py',
395-
train_instance_type='ml.p2.xlarge',
396-
train_instance_count=1,
396+
framework_version='1.6.0',
397397
py_version='py3',
398-
framework_version='1.6.0')
398+
train_instance_type='ml.p2.xlarge',
399+
train_instance_count=1)
399400
mxnet_estimator.fit('s3://my_bucket/my_training_data/')
400401
401402
# Deploy my estimator to an Amazon SageMaker Endpoint and get a Predictor
@@ -409,8 +410,8 @@ If using a pretrained model, create an ``MXNetModel`` object, and then call ``de
409410
mxnet_model = MXNetModel(model_data='s3://my_bucket/pretrained_model/model.tar.gz',
410411
role=role,
411412
entry_point='inference.py',
412-
py_version='py3',
413-
framework_version='1.6.0')
413+
framework_version='1.6.0',
414+
py_version='py3')
414415
predictor = mxnet_model.deploy(instance_type='ml.m4.xlarge',
415416
initial_instance_count=1)
416417

src/sagemaker/cli/mxnet.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.common import HostCommand, TrainCommand
17+
from sagemaker.mxnet import defaults
1718

1819

1920
def train(args):
@@ -40,13 +41,14 @@ def create_estimator(self):
4041
from sagemaker.mxnet.estimator import MXNet
4142

4243
return MXNet(
43-
self.script,
44+
entry_point=self.script,
45+
framework_version=defaults.MXNET_VERSION,
46+
py_version=self.python,
4447
role=self.role_name,
4548
base_job_name=self.job_name,
4649
train_instance_count=self.instance_count,
4750
train_instance_type=self.instance_type,
4851
hyperparameters=self.hyperparameters,
49-
py_version=self.python,
5052
)
5153

5254

@@ -64,6 +66,7 @@ def create_model(self, model_url):
6466
model_data=model_url,
6567
role=self.role_name,
6668
entry_point=self.script,
69+
framework_version=defaults.MXNET_VERSION,
6770
py_version=self.python,
6871
name=self.endpoint_name,
6972
env=self.environment,

src/sagemaker/mxnet/estimator.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
2322
python_deprecation_warning,
2423
parameter_v2_rename_warning,
2524
is_version_equal_or_higher,
@@ -43,10 +42,10 @@ class MXNet(Framework):
4342
def __init__(
4443
self,
4544
entry_point,
45+
framework_version,
46+
py_version,
4647
source_dir=None,
4748
hyperparameters=None,
48-
py_version="py2",
49-
framework_version=None,
5049
image_name=None,
5150
distributions=None,
5251
**kwargs
@@ -73,6 +72,11 @@ def __init__(
7372
file which should be executed as the entry point to training.
7473
If ``source_dir`` is specified, then ``entry_point``
7574
must point to a file located at the root of ``source_dir``.
75+
framework_version (str): MXNet version you want to use for executing
76+
your model training code. List of supported versions
77+
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
78+
py_version (str): Python version you want to use for executing your
79+
model training code. One of 'py2' or 'py3'.
7680
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7781
with any other training source code dependencies aside from the entry
7882
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -84,12 +88,6 @@ def __init__(
8488
SageMaker. For convenience, this accepts other types for keys
8589
and values, but ``str()`` will be called to convert them before
8690
training.
87-
py_version (str): Python version you want to use for executing your
88-
model training code (default: 'py2'). One of 'py2' or 'py3'.
89-
framework_version (str): MXNet version you want to use for executing
90-
your model training code. List of supported versions
91-
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
92-
If not specified, this will default to 1.2.1.
9391
image_name (str): If specified, the estimator will use this image for training and
9492
hosting, instead of selecting the appropriate SageMaker official image based on
9593
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -110,34 +108,31 @@ def __init__(
110108
:class:`~sagemaker.estimator.Framework` and
111109
:class:`~sagemaker.estimator.EstimatorBase`.
112110
"""
113-
if framework_version is None:
111+
self.framework_version = framework_version
112+
self.py_version = py_version
113+
if self.py_version == "py2":
114114
logger.warning(
115-
empty_framework_version_warning(defaults.MXNET_VERSION, self.LATEST_VERSION)
115+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
116116
)
117-
self.framework_version = framework_version or defaults.MXNET_VERSION
118117

119118
if "enable_sagemaker_metrics" not in kwargs:
120119
# enable sagemaker metrics for MXNet v1.6 or greater:
121-
if is_version_equal_or_higher([1, 6], self.framework_version):
120+
if self.framework_version and is_version_equal_or_higher(
121+
[1, 6], self.framework_version
122+
):
122123
kwargs["enable_sagemaker_metrics"] = True
123124

124125
super(MXNet, self).__init__(
125126
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
126127
)
127128

128-
if py_version == "py2":
129-
logger.warning(
130-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
131-
)
132-
133129
if distributions is not None:
134130
logger.warning(parameter_v2_rename_warning("distributions", "distribution"))
135131
train_instance_type = kwargs.get("train_instance_type")
136132
warn_if_parameter_server_with_multi_gpu(
137133
training_instance_type=train_instance_type, distributions=distributions
138134
)
139135

140-
self.py_version = py_version
141136
self._configure_distribution(distributions)
142137

143138
def _configure_distribution(self, distributions):
@@ -221,12 +216,12 @@ def create_model(
221216
self.model_data,
222217
role or self.role,
223218
entry_point or self.entry_point,
219+
framework_version=self.framework_version,
220+
py_version=self.py_version,
224221
source_dir=(source_dir or self._model_source_dir()),
225222
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
226223
container_log_level=self.container_log_level,
227224
code_location=self.code_location,
228-
py_version=self.py_version,
229-
framework_version=self.framework_version,
230225
model_server_workers=model_server_workers,
231226
sagemaker_session=self.sagemaker_session,
232227
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -257,6 +252,8 @@ class constructor
257252
if not framework:
258253
# If we were unable to parse the framework name from the image it is not one of our
259254
# officially supported images, in this case just add the image to the init params.
255+
init_params["framework_version"] = None
256+
init_params["py_version"] = None
260257
init_params["image_name"] = image_name
261258
return init_params
262259

src/sagemaker/mxnet/model.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
import packaging.version
1919

2020
import sagemaker
21-
from sagemaker.fw_utils import (
22-
create_image_uri,
23-
model_code_key_prefix,
24-
python_deprecation_warning,
25-
empty_framework_version_warning,
26-
)
21+
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
2722
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2823
from sagemaker.mxnet import defaults
2924
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
@@ -65,9 +60,9 @@ def __init__(
6560
model_data,
6661
role,
6762
entry_point,
63+
framework_version,
64+
py_version,
6865
image=None,
69-
py_version="py2",
70-
framework_version=None,
7166
predictor_cls=MXNetPredictor,
7267
model_server_workers=None,
7368
**kwargs
@@ -86,12 +81,12 @@ def __init__(
8681
file which should be executed as the entry point to model
8782
hosting. If ``source_dir`` is specified, then ``entry_point``
8883
must point to a file located at the root of ``source_dir``.
89-
image (str): A Docker image URI (default: None). If not specified, a
90-
default image for MXNet will be used.
91-
py_version (str): Python version you want to use for executing your
92-
model training code (default: 'py2').
9384
framework_version (str): MXNet version you want to use for executing
9485
your model training code.
86+
py_version (str): Python version you want to use for executing your
87+
model training code.
88+
image (str): A Docker image URI (default: None). If not specified, a
89+
default image for MXNet will be used.
9590
predictor_cls (callable[str, sagemaker.session.Session]): A function
9691
to call to create a predictor with an endpoint name and
9792
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -108,22 +103,17 @@ def __init__(
108103
:class:`~sagemaker.model.FrameworkModel` and
109104
:class:`~sagemaker.model.Model`.
110105
"""
111-
super(MXNetModel, self).__init__(
112-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
113-
)
114-
115-
if py_version == "py2":
106+
self.framework_version = framework_version
107+
self.py_version = py_version
108+
if self.py_version == "py2":
116109
logger.warning(
117110
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118111
)
119112

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.MXNET_VERSION, defaults.LATEST_VERSION)
123-
)
113+
super(MXNetModel, self).__init__(
114+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
115+
)
124116

125-
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.MXNET_VERSION
127117
self.model_server_workers = model_server_workers
128118

129119
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def mxnet_version(request):
163163
return request.param
164164

165165

166+
@pytest.fixture(scope="module", params=["py2", "py3"])
167+
def mxnet_py_version(request):
168+
return request.param
169+
170+
166171
@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"])
167172
def pytorch_version(request):
168173
return request.param

tests/unit/test_multidatamodel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
MXNET_MODEL_DATA = "s3://mybucket/mxnet_path/model.tar.gz"
2828
MXNET_MODEL_NAME = "dummy-mxnet-model"
2929
MXNET_ROLE = "DummyMXNetRole"
30-
MXNET_IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.2-cpu-py2"
30+
MXNET_FRAMEWORK_VERSION = "1.2"
31+
MXNET_PY_VERSION = "py2"
32+
MXNET_IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-cpu-{}".format(
33+
MXNET_FRAMEWORK_VERSION, MXNET_PY_VERSION
34+
)
3135

3236
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
3337
IMAGE = "123456789012.dkr.ecr.dummyregion.amazonaws.com/dummyimage:latest"
@@ -100,8 +104,10 @@ def multi_data_model(sagemaker_session):
100104
def mxnet_model(sagemaker_session):
101105
return MXNetModel(
102106
MXNET_MODEL_DATA,
103-
role=MXNET_ROLE,
104107
entry_point=ENTRY_POINT,
108+
framework_version=MXNET_FRAMEWORK_VERSION,
109+
py_version=MXNET_PY_VERSION,
110+
role=MXNET_ROLE,
105111
sagemaker_session=sagemaker_session,
106112
name=MXNET_MODEL_NAME,
107113
enable_network_isolation=True,

0 commit comments

Comments
 (0)