Skip to content

breaking: require framework_version, py_version for pytorch #1568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 11, 2020
Merged
6 changes: 4 additions & 2 deletions doc/frameworks/pytorch/using_pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ directories ('train' and 'test').
pytorch_estimator = PyTorch('pytorch-train.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0',
framework_version='1.5.0',
py_version='py3',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
Expand Down Expand Up @@ -247,7 +248,8 @@ operation.
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0')
framework_version='1.5.0',
py_version='py3')
pytorch_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
Expand Down
17 changes: 17 additions & 0 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,20 @@ def _region_supports_debugger(region_name):

"""
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS


def validate_version_or_image_args(framework_version, py_version, image_name):
"""Checks if version or image arguments are specified.

Validates framework and model arguments to enforce version or image specification.

Args:
framework_version (str): The version of the framework.
py_version (str): The version of Python.
image_name (str): The URI of the image.
"""
if (framework_version is None or py_version is None) and image_name is None:
raise ValueError(
"framework_version or py_version was None, yet image_name was also None. "
"Either specify both framework_version and py_version, or specify image_name."
)
57 changes: 31 additions & 26 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
python_deprecation_warning,
is_version_equal_or_higher,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
Expand All @@ -40,10 +40,10 @@ class PyTorch(Framework):
def __init__(
self,
entry_point,
framework_version=None,
py_version=None,
source_dir=None,
hyperparameters=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
image_name=None,
**kwargs
):
Expand All @@ -69,6 +69,13 @@ def __init__(
file which should be executed as the entry point to training.
If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): PyTorch version you want to use for
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
unless ``image_name`` is provided.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other training source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
Expand All @@ -80,12 +87,6 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3'). One of 'py2' or 'py3'.
framework_version (str): PyTorch version you want to use for
executing your model training code. List of supported versions
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
If not specified, this will default to 0.4.
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
Expand All @@ -95,6 +96,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
will be raised.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.

Expand All @@ -104,28 +108,25 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.framework_version = framework_version
self.py_version = py_version

if "enable_sagemaker_metrics" not in kwargs:
# enable sagemaker metrics for PT v1.3 or greater:
if is_version_equal_or_higher([1, 3], self.framework_version):
if self.framework_version and is_version_equal_or_higher(
[1, 3], self.framework_version
):
kwargs["enable_sagemaker_metrics"] = True

super(PyTorch, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

self.py_version = py_version

def create_model(
self,
model_server_workers=None,
Expand Down Expand Up @@ -177,12 +178,12 @@ def create_model(
self.model_data,
role or self.role,
entry_point or self.entry_point,
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down Expand Up @@ -210,15 +211,19 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

if tag is None:
framework_version = None
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version
init_params["framework_version"] = framework_version_from_tag(tag)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
Expand Down
38 changes: 20 additions & 18 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
Expand Down Expand Up @@ -66,9 +66,9 @@ def __init__(
model_data,
role,
entry_point,
image=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
py_version=None,
image=None,
predictor_cls=PyTorchPredictor,
model_server_workers=None,
**kwargs
Expand All @@ -87,12 +87,18 @@ def __init__(
file which should be executed as the entry point to model
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): PyTorch version you want to use for
executing your model training code. Defaults to None. Required
unless ``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
image (str): A Docker image URI (default: None). If not specified, a
default image for PyTorch will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3').
framework_version (str): PyTorch version you want to use for
executing your model training code.

If ``framework_version`` or ``py_version`` are ``None``, then
``image`` is required. If also ``None``, then a ``ValueError``
will be raised.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -109,22 +115,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

if py_version == "py2":
validate_version_or_image_args(framework_version, py_version, image)
if py_version and py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
)
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def pytorch_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def pytorch_py_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.20.0"])
def sklearn_version(request):
return request.param
Expand Down
4 changes: 3 additions & 1 deletion tests/integ/test_airflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,8 @@ def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provid
estimator = PyTorch(
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.1.0",
framework_version="1.3.1",
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
Expand All @@ -638,6 +639,7 @@ def test_pytorch_12_airflow_config_uploads_data_source_to_s3_when_inputs_not_pro
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.2.0",
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
Expand Down
18 changes: 8 additions & 10 deletions tests/integ/test_pytorch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,9 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
predictor.delete_endpoint()


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
def test_deploy_model(
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
Expand All @@ -114,6 +112,8 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
model_data,
"SageMakerRole",
entry_point=MNIST_SCRIPT,
framework_version=pytorch_full_version,
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
Expand All @@ -125,10 +125,6 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
assert output.shape == (batch_size, 10)


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

Expand All @@ -139,6 +135,7 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
"SageMakerRole",
entry_point="mnist.py",
framework_version="1.4.0",
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
Expand All @@ -160,8 +157,9 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
pytorch = PyTorchModel(
model_data,
"SageMakerRole",
framework_version="1.3.1",
entry_point=EIA_SCRIPT,
framework_version="1.3.1",
py_version="py3",
sagemaker_session=sagemaker_session,
)
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_source_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.local_mode
def test_source_dirs(tmpdir, sagemaker_local_session):
def test_source_dirs(tmpdir, sagemaker_local_session, pytorch_full_version):
source_dir = os.path.join(DATA_DIR, "pytorch_source_dirs")
lib = os.path.join(str(tmpdir), "alexa.py")

Expand All @@ -35,6 +35,7 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
role="SageMakerRole",
source_dir=source_dir,
dependencies=[lib],
framework_version=pytorch_full_version,
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
Expand Down
2 changes: 2 additions & 0 deletions tests/integ/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sagemaker.mxnet.estimator import MXNet
from sagemaker.predictor import json_deserializer
from sagemaker.pytorch import PyTorch
from sagemaker.pytorch.defaults import PYTORCH_VERSION
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.defaults import LATEST_VERSION
from sagemaker.tuner import (
Expand Down Expand Up @@ -827,6 +828,7 @@ def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
entry_point=mnist_script,
role="SageMakerRole",
train_instance_count=1,
framework_version=PYTORCH_VERSION,
py_version=PYTHON_VERSION,
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
Expand Down
Loading