Skip to content

breaking: require framework_version, py_version for pytorch #1568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 11, 2020
6 changes: 4 additions & 2 deletions doc/frameworks/pytorch/using_pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ directories ('train' and 'test').
pytorch_estimator = PyTorch('pytorch-train.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0',
framework_version='1.5.0',
py_version='py3',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
Expand Down Expand Up @@ -247,7 +248,8 @@ operation.
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='1.0.0')
framework_version='1.5.0',
py_version='py3')
pytorch_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
Expand Down
57 changes: 31 additions & 26 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
python_deprecation_warning,
is_version_equal_or_higher,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
Expand All @@ -40,10 +40,10 @@ class PyTorch(Framework):
def __init__(
self,
entry_point,
framework_version=None,
py_version=None,
source_dir=None,
hyperparameters=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
image_name=None,
**kwargs
):
Expand All @@ -69,6 +69,13 @@ def __init__(
file which should be executed as the entry point to training.
If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): PyTorch version you want to use for
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
unless ``image_name`` is provided.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other training source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
Expand All @@ -80,12 +87,6 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3'). One of 'py2' or 'py3'.
framework_version (str): PyTorch version you want to use for
executing your model training code. List of supported versions
https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators.
If not specified, this will default to 0.4.
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
Expand All @@ -95,6 +96,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
will be raised.
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.

Expand All @@ -104,28 +108,25 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.framework_version = framework_version
self.py_version = py_version

if "enable_sagemaker_metrics" not in kwargs:
# enable sagemaker metrics for PT v1.3 or greater:
if is_version_equal_or_higher([1, 3], self.framework_version):
if self.framework_version and is_version_equal_or_higher(
[1, 3], self.framework_version
):
kwargs["enable_sagemaker_metrics"] = True

super(PyTorch, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

self.py_version = py_version

def create_model(
self,
model_server_workers=None,
Expand Down Expand Up @@ -177,12 +178,12 @@ def create_model(
self.model_data,
role or self.role,
entry_point or self.entry_point,
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
code_location=self.code_location,
py_version=self.py_version,
framework_version=self.framework_version,
model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
Expand Down Expand Up @@ -210,15 +211,19 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

if tag is None:
framework_version = None
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version
init_params["framework_version"] = framework_version_from_tag(tag)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
Expand Down
38 changes: 20 additions & 18 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
Expand Down Expand Up @@ -66,9 +66,9 @@ def __init__(
model_data,
role,
entry_point,
image=None,
py_version=defaults.PYTHON_VERSION,
framework_version=None,
py_version=None,
image=None,
predictor_cls=PyTorchPredictor,
model_server_workers=None,
**kwargs
Expand All @@ -87,12 +87,18 @@ def __init__(
file which should be executed as the entry point to model
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
framework_version (str): PyTorch version you want to use for
executing your model training code. Defaults to None. Required
unless ``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
image (str): A Docker image URI (default: None). If not specified, a
default image for PyTorch will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py3').
framework_version (str): PyTorch version you want to use for
executing your model training code.

If ``framework_version`` or ``py_version`` are ``None``, then
``image`` is required. If also ``None``, then a ``ValueError``
will be raised.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -109,22 +115,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

if py_version == "py2":
validate_version_or_image_args(framework_version, py_version, image)
if py_version and py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)
)
super(PyTorchModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.PYTORCH_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ def pytorch_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def pytorch_py_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.20.0"])
def sklearn_version(request):
return request.param
Expand Down
4 changes: 3 additions & 1 deletion tests/integ/test_airflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,8 @@ def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provid
estimator = PyTorch(
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.1.0",
framework_version="1.3.1",
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
Expand All @@ -639,6 +640,7 @@ def test_pytorch_12_airflow_config_uploads_data_source_to_s3_when_inputs_not_pro
entry_point=PYTORCH_MNIST_SCRIPT,
role=ROLE,
framework_version="1.2.0",
py_version="py3",
train_instance_count=2,
train_instance_type=cpu_instance_type,
hyperparameters={"epochs": 6, "backend": "gloo"},
Expand Down
3 changes: 3 additions & 0 deletions tests/integ/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def test_github(sagemaker_local_session):
script_path = "mnist.py"
data_path = os.path.join(DATA_DIR, "pytorch_mnist")
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}

# TODO: fails for newer pytorch versions when using MNIST from torchvision due to missing dataset
# "algo-1-v767u_1 | RuntimeError: Dataset not found. You can use download=True to download it"
pytorch = PyTorch(
entry_point=script_path,
role="SageMakerRole",
Expand Down
20 changes: 9 additions & 11 deletions tests/integ/test_pytorch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,9 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
predictor.delete_endpoint()


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
def test_deploy_model(
pytorch_training_job, sagemaker_session, cpu_instance_type, pytorch_full_version
):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
Expand All @@ -114,6 +112,8 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
model_data,
"SageMakerRole",
entry_point=MNIST_SCRIPT,
framework_version=pytorch_full_version,
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
Expand All @@ -125,10 +125,6 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type
assert output.shape == (batch_size, 10)


@pytest.mark.skipif(
PYTHON_VERSION == "py2",
reason="Python 2 is supported by PyTorch {} and lower versions.".format(LATEST_PY2_VERSION),
)
def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

Expand All @@ -139,6 +135,7 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
"SageMakerRole",
entry_point="mnist.py",
framework_version="1.4.0",
py_version="py3",
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
Expand All @@ -160,8 +157,9 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
pytorch = PyTorchModel(
model_data,
"SageMakerRole",
framework_version="1.3.1",
entry_point=EIA_SCRIPT,
framework_version="1.3.1",
py_version="py3",
sagemaker_session=sagemaker_session,
)
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
Expand Down Expand Up @@ -193,7 +191,7 @@ def _get_pytorch_estimator(
entry_point=entry_point,
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version=PYTHON_VERSION,
py_version="py3",
train_instance_count=1,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
Expand Down
4 changes: 4 additions & 0 deletions tests/integ/test_source_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tests.integ.lock as lock
from tests.integ import DATA_DIR, PYTHON_VERSION

from sagemaker.pytorch.defaults import PYTORCH_VERSION
from sagemaker.pytorch.estimator import PyTorch


Expand All @@ -30,11 +31,14 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
with open(lib, "w") as f:
f.write("def question(to_anything): return 42")

# TODO: fails on newer versions of pytorch in call to np.load(BytesIO(stream.read()))
# "ValueError: Cannot load file containing pickled data when allow_pickle=False"
estimator = PyTorch(
entry_point="train.py",
role="SageMakerRole",
source_dir=source_dir,
dependencies=[lib],
framework_version=PYTORCH_VERSION,
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type="local",
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_transform_pytorch_vpc_custom_model_bucket(
entry_point=os.path.join(data_dir, "mnist.py"),
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version=PYTHON_VERSION,
py_version="py3",
sagemaker_session=sagemaker_session,
vpc_config={"Subnets": subnet_ids, "SecurityGroupIds": [security_group_id]},
code_location="s3://{}".format(custom_bucket_name),
Expand Down
5 changes: 3 additions & 2 deletions tests/integ/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,15 +819,16 @@ def test_tuning_chainer(sagemaker_session, cpu_instance_type):
reason="This test has always failed, but the failure was masked by a bug. "
"This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
)
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type):
def test_attach_tuning_pytorch(sagemaker_session, cpu_instance_type, pytorch_full_version):
mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
mnist_script = os.path.join(mnist_dir, "mnist.py")

estimator = PyTorch(
entry_point=mnist_script,
role="SageMakerRole",
train_instance_count=1,
py_version=PYTHON_VERSION,
framework_version=pytorch_full_version,
py_version="py3",
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
)
Expand Down
Loading