Skip to content

Commit f2a13ef

Browse files
committed
breaking: require framework_version, py_version for pytorch
* framework_version, py_version required for framework PyTorch * framework_version, py_version required for framework PyTorchModel * re-order of args: convention to follow entry_point, push args with defaults below * testing updates * doc updates * ignore coverage results for py27 env due to v2 migration scripts
1 parent 18e402c commit f2a13ef

File tree

11 files changed

+118
-90
lines changed

11 files changed

+118
-90
lines changed

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ directories ('train' and 'test').
154154
pytorch_estimator = PyTorch('pytorch-train.py',
155155
train_instance_type='ml.p3.2xlarge',
156156
train_instance_count=1,
157-
framework_version='1.0.0',
157+
framework_version='1.5.0',
158+
py_version='py3',
158159
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
159160
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
160161
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -247,7 +248,8 @@ operation.
247248
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
248249
train_instance_type='ml.p3.2xlarge',
249250
train_instance_count=1,
250-
framework_version='1.0.0')
251+
framework_version='1.5.0',
252+
py_version='py3')
251253
pytorch_estimator.fit('s3://my_bucket/my_training_data/')
252254
253255
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/fw_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,22 @@ def _region_supports_debugger(region_name):
681681
682682
"""
683683
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS
684+
685+
686+
def validate_version_or_image_args(framework_version, py_version, image_name):
687+
"""Checks if version or image arguments are specified.
688+
689+
Used to validate framework and model arguments to enforce version or image specification.
690+
Raises ValueError if version or image arguments are not specified.
691+
692+
Args:
693+
framework_version (str): the version of the framework
694+
py_version (str): the version of python
695+
image_name (str): the uri of the image
696+
"""
697+
if (framework_version is None or py_version is None) and image_name is None:
698+
raise ValueError(
699+
"framework_version or py_version was None, yet image_name was also None. "
700+
"Either specify both framework_version and py_version, or specify image_name."
701+
)
702+
return True

src/sagemaker/pytorch/estimator.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
23-
python_deprecation_warning,
2422
is_version_equal_or_higher,
23+
python_deprecation_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.pytorch import defaults
2727
from sagemaker.pytorch.model import PyTorchModel
@@ -40,10 +40,10 @@ class PyTorch(Framework):
4040
def __init__(
4141
self,
4242
entry_point,
43+
framework_version=None,
44+
py_version=None,
4345
source_dir=None,
4446
hyperparameters=None,
45-
py_version=defaults.PYTHON_VERSION,
46-
framework_version=None,
4747
image_name=None,
4848
**kwargs
4949
):
@@ -69,6 +69,11 @@ def __init__(
6969
file which should be executed as the entry point to training.
7070
If ``source_dir`` is specified, then ``entry_point``
7171
must point to a file located at the root of ``source_dir``.
72+
framework_version (str): PyTorch version you want to use for
73+
executing your model training code. Defaults to None. List of supported versions.
74+
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
75+
py_version (str): Python version you want to use for executing your
76+
model training code. One of 'py2' or 'py3'. Defaults to None.
7277
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7378
with any other training source code dependencies aside from the entry
7479
point file (default: None). If ``source_dir`` is an S3 URI, it must
@@ -80,12 +85,6 @@ def __init__(
8085
SageMaker. For convenience, this accepts other types for keys
8186
and values, but ``str()`` will be called to convert them before
8287
training.
83-
py_version (str): Python version you want to use for executing your
84-
model training code (default: 'py3'). One of 'py2' or 'py3'.
85-
framework_version (str): PyTorch version you want to use for
86-
executing your model training code. List of supported versions
87-
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
88-
If not specified, this will default to 0.4.
8988
image_name (str): If specified, the estimator will use this image
9089
for training and hosting, instead of selecting the appropriate
9190
SageMaker official image based on framework_version and
@@ -95,6 +94,9 @@ def __init__(
9594
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
9695
* ``custom-image:latest``
9796
97+
If ``framework_version`` or ``py_version`` are ``None``, then
98+
``image_name`` is required. If also ``None``, then a ``ValueError``
99+
will be raised.
98100
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
99101
constructor.
100102
@@ -104,28 +106,25 @@ def __init__(
104106
:class:`~sagemaker.estimator.Framework` and
105107
:class:`~sagemaker.estimator.EstimatorBase`.
106108
"""
107-
if framework_version is None:
109+
validate_version_or_image_args(framework_version, py_version, image_name)
110+
if py_version == "py2":
108111
logger.warning(
109-
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
112+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
110113
)
111-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
114+
self.framework_version = framework_version
115+
self.py_version = py_version
112116

113117
if "enable_sagemaker_metrics" not in kwargs:
114118
# enable sagemaker metrics for PT v1.3 or greater:
115-
if is_version_equal_or_higher([1, 3], self.framework_version):
119+
if self.framework_version and is_version_equal_or_higher(
120+
[1, 3], self.framework_version
121+
):
116122
kwargs["enable_sagemaker_metrics"] = True
117123

118124
super(PyTorch, self).__init__(
119125
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
120126
)
121127

122-
if py_version == "py2":
123-
logger.warning(
124-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
125-
)
126-
127-
self.py_version = py_version
128-
129128
def create_model(
130129
self,
131130
model_server_workers=None,
@@ -177,12 +176,12 @@ def create_model(
177176
self.model_data,
178177
role or self.role,
179178
entry_point or self.entry_point,
179+
framework_version=self.framework_version,
180+
py_version=self.py_version,
180181
source_dir=(source_dir or self._model_source_dir()),
181182
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
182183
container_log_level=self.container_log_level,
183184
code_location=self.code_location,
184-
py_version=self.py_version,
185-
framework_version=self.framework_version,
186185
model_server_workers=model_server_workers,
187186
sagemaker_session=self.sagemaker_session,
188187
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -210,15 +209,19 @@ class constructor
210209
image_name = init_params.pop("image")
211210
framework, py_version, tag, _ = framework_name_from_image(image_name)
212211

212+
if tag is None:
213+
framework_version = None
214+
else:
215+
framework_version = framework_version_from_tag(tag)
216+
init_params["framework_version"] = framework_version
217+
init_params["py_version"] = py_version
218+
213219
if not framework:
214220
# If we were unable to parse the framework name from the image it is not one of our
215221
# officially supported images, in this case just add the image to the init params.
216222
init_params["image_name"] = image_name
217223
return init_params
218224

219-
init_params["py_version"] = py_version
220-
init_params["framework_version"] = framework_version_from_tag(tag)
221-
222225
training_job_name = init_params["base_job_name"]
223226

224227
if framework != cls.__framework_name__:

src/sagemaker/pytorch/model.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
create_image_uri,
2222
model_code_key_prefix,
2323
python_deprecation_warning,
24-
empty_framework_version_warning,
24+
validate_version_or_image_args,
2525
)
2626
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2727
from sagemaker.pytorch import defaults
@@ -66,9 +66,9 @@ def __init__(
6666
model_data,
6767
role,
6868
entry_point,
69-
image=None,
70-
py_version=defaults.PYTHON_VERSION,
7169
framework_version=None,
70+
py_version=None,
71+
image=None,
7272
predictor_cls=PyTorchPredictor,
7373
model_server_workers=None,
7474
**kwargs
@@ -87,12 +87,16 @@ def __init__(
8787
file which should be executed as the entry point to model
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
90+
framework_version (str): PyTorch version you want to use for
91+
executing your model training code. Defaults to None.
92+
py_version (str): Python version you want to use for executing your
93+
model training code. Defaults to None.
9094
image (str): A Docker image URI (default: None). If not specified, a
9195
default image for PyTorch will be used.
92-
py_version (str): Python version you want to use for executing your
93-
model training code (default: 'py3').
94-
framework_version (str): PyTorch version you want to use for
95-
executing your model training code.
96+
97+
If ``framework_version`` or ``py_version`` are ``None``, then
98+
``image_name`` is required. If also ``None``, then a ``ValueError``
99+
will be raised.
96100
predictor_cls (callable[str, sagemaker.session.Session]): A function
97101
to call to create a predictor with an endpoint name and
98102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,22 +113,18 @@ def __init__(
109113
:class:`~sagemaker.model.FrameworkModel` and
110114
:class:`~sagemaker.model.Model`.
111115
"""
112-
super(PyTorchModel, self).__init__(
113-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
114-
)
115-
116-
if py_version == "py2":
116+
validate_version_or_image_args(framework_version, py_version, image)
117+
if py_version and py_version == "py2":
117118
logger.warning(
118119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
119120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
120123

121-
if framework_version is None:
122-
logger.warning(
123-
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
124-
)
124+
super(PyTorchModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
125127

126-
self.py_version = py_version
127-
self.framework_version = framework_version or defaults.PYTORCH_VERSION
128128
self.model_server_workers = model_server_workers
129129

130130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ def pytorch_version(request):
168168
return request.param
169169

170170

171+
@pytest.fixture(scope="module", params=["py2", "py3"])
172+
def pytorch_py_version(request):
173+
return request.param
174+
175+
171176
@pytest.fixture(scope="module", params=["0.20.0"])
172177
def sklearn_version(request):
173178
return request.param

tests/integ/test_airflow_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,8 @@ def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provid
613613
estimator = PyTorch(
614614
entry_point=PYTORCH_MNIST_SCRIPT,
615615
role=ROLE,
616-
framework_version="1.1.0",
616+
framework_version="1.3.1",
617+
framework_version="py3",
617618
train_instance_count=2,
618619
train_instance_type=cpu_instance_type,
619620
hyperparameters={"epochs": 6, "backend": "gloo"},
@@ -638,6 +639,7 @@ def test_pytorch_12_airflow_config_uploads_data_source_to_s3_when_inputs_not_pro
638639
entry_point=PYTORCH_MNIST_SCRIPT,
639640
role=ROLE,
640641
framework_version="1.2.0",
642+
py_version="py3",
641643
train_instance_count=2,
642644
train_instance_type=cpu_instance_type,
643645
hyperparameters={"epochs": 6, "backend": "gloo"},

tests/integ/test_pytorch_train.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
102102
PYTHON_VERSION == "py2",
103103
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
104104
)
105-
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
105+
def test_deploy_model(
106+
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
107+
):
106108
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
107109

108110
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -114,6 +116,8 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
114116
model_data,
115117
"SageMakerRole",
116118
entry_point=MNIST_SCRIPT,
119+
framework_version=pytorch_full_version,
120+
py_version="py3",
117121
sagemaker_session=sagemaker_session,
118122
)
119123
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -139,6 +143,7 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
139143
"SageMakerRole",
140144
entry_point="mnist.py",
141145
framework_version="1.4.0",
146+
py_version="py3",
142147
sagemaker_session=sagemaker_session,
143148
)
144149
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
@@ -160,8 +165,9 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
160165
pytorch = PyTorchModel(
161166
model_data,
162167
"SageMakerRole",
163-
framework_version="1.3.1",
164168
entry_point=EIA_SCRIPT,
169+
framework_version="1.3.1",
170+
py_version="py3",
165171
sagemaker_session=sagemaker_session,
166172
)
167173
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):

tests/integ/test_source_dirs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

2222
from sagemaker.pytorch.estimator import PyTorch
23+
from sagemaker.pytorch.defaults import PYTORCH_VERSION
2324

2425

2526
@pytest.mark.local_mode
@@ -35,6 +36,7 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3536
role="SageMakerRole",
3637
source_dir=source_dir,
3738
dependencies=[lib],
39+
framework_version=PYTORCH_VERSION,
3840
py_version=PYTHON_VERSION,
3941
train_instance_count=1,
4042
train_instance_type="local",

tests/integ/test_tuner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from sagemaker.mxnet.estimator import MXNet
3737
from sagemaker.predictor import json_deserializer
3838
from sagemaker.pytorch import PyTorch
39+
from sagemaker.pytorch.defaults import PYTORCH_VERSION
3940
from sagemaker.tensorflow import TensorFlow
4041
from sagemaker.tensorflow.defaults import LATEST_VERSION
4142
from sagemaker.tuner import (
@@ -827,6 +828,7 @@ def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
827828
entry_point=mnist_script,
828829
role="SageMakerRole",
829830
train_instance_count=1,
831+
framework_version=PYTORCH_VERSION,
830832
py_version=PYTHON_VERSION,
831833
train_instance_type=cpu_instance_type,
832834
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)