
feature: Add support for PyTorch 1.2.0 #1091


Merged: 12 commits, merged on Oct 15, 2019
2 changes: 1 addition & 1 deletion README.rst
@@ -222,7 +222,7 @@ PyTorch SageMaker Estimators

With PyTorch SageMaker Estimators, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.


We recommend that you use the latest supported version, because that's where we focus most of our development efforts.

Contributor: also, the table at the bottom of this README should be updated (feel free to do in a separate PR)

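For context, a minimal training sketch using the newly supported version might look like the following; the role, entry point, and S3 path are placeholders, not values from this PR:

```python
from sagemaker.pytorch import PyTorch

# Placeholder values: substitute your own role, script, and data location.
estimator = PyTorch(
    entry_point="mnist.py",
    role="SageMakerRole",
    framework_version="1.2.0",  # newly supported version
    py_version="py3",
    train_instance_count=1,
    train_instance_type="ml.p3.2xlarge",
)
estimator.fit({"training": "s3://my-bucket/pytorch/mnist"})
```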
24 changes: 23 additions & 1 deletion src/sagemaker/fw_utils.py
@@ -67,6 +67,8 @@
"tensorflow-serving-eia": "tensorflow-inference-eia",
"mxnet": "mxnet-training",
"mxnet-serving": "mxnet-inference",
"pytorch": "pytorch-training",
"pytorch-serving": "pytorch-inference",
"mxnet-serving-eia": "mxnet-inference-eia",
}

@@ -76,6 +78,8 @@
"tensorflow-serving-eia": [1, 14, 0],
"mxnet": [1, 4, 1],
"mxnet-serving": [1, 4, 1],
"pytorch": [1, 2, 0],
"pytorch-serving": [1, 2, 0],
"mxnet-serving-eia": [1, 4, 1],
}

@@ -119,10 +123,15 @@ def _using_merged_images(region, framework, py_version, framework_version):
    is_gov_region = region in VALID_ACCOUNTS_BY_REGION
    is_py3 = py_version == "py3" or py_version is None
    is_merged_versions = _is_merged_versions(framework, framework_version)

    return (
        ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION)
        and is_merged_versions
        and (is_py3 or _is_tf_14_or_later(framework, framework_version))
        and (
            is_py3
            or _is_tf_14_or_later(framework, framework_version)
            or _is_pt_12_or_later(framework, framework_version)
        )
    )

Contributor: not in scope here, but I think it would be cleaner to maintain a constant that keeps track of which frameworks have both Python 2 & 3 support versus just Python 3, rather than writing new methods for each framework. Maybe let is_merged_versions take in the Python version as another parameter.

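A rough sketch of that suggestion, with made-up names and illustrative values that are not part of this PR, could look like:

```python
# Hypothetical sketch only: one constant tracks the lowest merged version per
# framework, another tracks which merged frameworks also support py2.
_MERGED_LOWEST_VERSIONS = {
    "pytorch": [1, 2, 0],
    "pytorch-serving": [1, 2, 0],
    "tensorflow-scriptmode": [1, 13, 1],
}
_MERGED_PY2_FRAMEWORKS = {"pytorch", "pytorch-serving"}


def _is_merged_versions(framework, framework_version, py_version):
    lowest = _MERGED_LOWEST_VERSIONS.get(framework)
    if lowest is None:
        return False
    if py_version == "py2" and framework not in _MERGED_PY2_FRAMEWORKS:
        return False
    version = [int(s) for s in framework_version.split(".")]
    return version >= lowest[0 : len(version)]
```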

@@ -140,6 +149,19 @@ def _is_tf_14_or_later(framework, framework_version):
)


def _is_pt_12_or_later(framework, framework_version):
    """
    Args:
        framework: Name of the framework
        framework_version: framework version
    """
    # Asimov team now owns PyTorch 1.2.0 py2 and py3
    asimov_lowest_pt = [1, 2, 0]
    version = [int(s) for s in framework_version.split(".")]
    is_pytorch = framework in ("pytorch", "pytorch-serving")
    return is_pytorch and version >= asimov_lowest_pt[0 : len(version)]

Contributor: I think "asimov" is an internal name, and probably shouldn't be used here

Contributor: +1

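As a quick check of the slice-based comparison in `_is_pt_12_or_later` (illustration only, not part of the PR), two- and three-part version strings compare against the `[1, 2, 0]` floor like this:

```python
# Mirrors the truncated-list comparison used in _is_pt_12_or_later.
lowest = [1, 2, 0]
for raw in ("1.1", "1.2", "1.2.0", "1.3.1"):
    version = [int(s) for s in raw.split(".")]
    print(raw, version >= lowest[0 : len(version)])
# 1.1 -> False, 1.2 -> True, 1.2.0 -> True, 1.3.1 -> True
```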

def _registry_id(region, framework, py_version, account, framework_version):
"""
Args:
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/README.rst
@@ -4,7 +4,7 @@ SageMaker PyTorch Estimators and Models

With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.

We recommend that you use the latest supported version, because that's where we focus most of our development efforts.

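A hosting sketch along these lines should also apply (the artifact path and script name are placeholders); per the changes in this PR, a `framework_version` of 1.2 or later is served from the new `pytorch-inference` images:

```python
from sagemaker.pytorch import PyTorchModel

# Placeholder values: substitute your own model artifact, role, and script.
model = PyTorchModel(
    model_data="s3://my-bucket/model.tar.gz",
    role="SageMakerRole",
    entry_point="inference.py",
    framework_version="1.2",
)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.c4.xlarge")
```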
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/estimator.py
@@ -34,7 +34,7 @@ class PyTorch(Framework):

__framework_name__ = "pytorch"

LATEST_VERSION = "1.1"
LATEST_VERSION = "1.2.0"
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""

def __init__(
19 changes: 16 additions & 3 deletions src/sagemaker/pytorch/model.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

import logging
import pkg_resources

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
@@ -53,6 +54,7 @@ class PyTorchModel(FrameworkModel):
"""

__framework_name__ = "pytorch"
_LOWEST_MMS_VERSION = "1.2"

def __init__(
self,
@@ -122,22 +124,33 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
dict[str, str]: A container definition object usable with the
CreateModel API.
"""
lowest_mms_version = pkg_resources.parse_version(self._LOWEST_MMS_VERSION)
framework_version = pkg_resources.parse_version(self.framework_version)
is_mms_version = framework_version >= lowest_mms_version

deploy_image = self.image
if not deploy_image:
region_name = self.sagemaker_session.boto_session.region_name

framework_name = self.__framework_name__
if is_mms_version:
framework_name += "-serving"

deploy_image = create_image_uri(
region_name,
self.__framework_name__,
framework_name,
instance_type,
self.framework_version,
self.py_version,
accelerator_type=accelerator_type,
)
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
self._upload_code(deploy_key_prefix, repack=is_mms_version)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
return sagemaker.container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
)
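For reference, a small illustration (not part of the diff) of how the `pkg_resources` comparison above gates the MMS serving path and model repacking:

```python
from pkg_resources import parse_version

# Mirrors the check in prepare_container_def.
_LOWEST_MMS_VERSION = "1.2"
for fw_version in ("1.0.0", "1.1.0", "1.2", "1.2.0"):
    is_mms_version = parse_version(fw_version) >= parse_version(_LOWEST_MMS_VERSION)
    print(fw_version, is_mms_version)
# 1.0.0 -> False, 1.1.0 -> False, 1.2 -> True, 1.2.0 -> True
```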
25 changes: 25 additions & 0 deletions tests/integ/test_pytorch_train.py
@@ -54,6 +54,31 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
assert output.shape == (batch_size, 10)


@pytest.mark.local_mode
def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
pytorch = PyTorch(
entry_point=MNIST_SCRIPT,
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version="py3",
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_local_session,
)

pytorch.fit({"training": "file://" + os.path.join(MNIST_DIR, "training")})

predictor = pytorch.deploy(1, "local")
try:
batch_size = 100
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
output = predictor.predict(data)

assert output.shape == (batch_size, 10)
finally:
predictor.delete_endpoint()


def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

22 changes: 22 additions & 0 deletions tests/unit/test_fw_utils.py
@@ -313,6 +313,28 @@ def test_create_image_uri_merged_gov_regions():
)


def test_create_image_uri_merged_pytorch():

image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.2", "py2")
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.2-gpu-py2"

image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.1", "py2")
assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:1.1-gpu-py2"

image_uri = fw_utils.create_image_uri(
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.2", "py2"
)
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.2-cpu-py2"

image_uri = fw_utils.create_image_uri(
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.1", "py2"
)
assert (
image_uri
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch-serving:1.1-cpu-py2"
)


def test_create_image_uri_accelerator_tf():
image_uri = fw_utils.create_image_uri(
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
39 changes: 37 additions & 2 deletions tests/unit/test_pytorch.py
@@ -17,8 +17,7 @@
import os
import pytest
import sys
from mock import MagicMock, Mock
from mock import patch
from mock import ANY, MagicMock, Mock, patch

from sagemaker.pytorch import defaults
from sagemaker.pytorch import PyTorch
@@ -296,6 +295,42 @@ def test_model(sagemaker_session):
assert isinstance(predictor, PyTorchPredictor)


@patch("sagemaker.utils.create_tar_file", MagicMock())
@patch("sagemaker.utils.repack_model")
def test_mms_model(repack_model, sagemaker_session):
PyTorchModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version="1.2",
).deploy(1, GPU)

repack_model.assert_called_with(
dependencies=[],
inference_script=SCRIPT_PATH,
kms_key=None,
model_uri="s3://some/data.tar.gz",
repacked_model_uri=ANY,
sagemaker_session=sagemaker_session,
source_directory=None,
)


@patch("sagemaker.utils.create_tar_file", MagicMock())
@patch("sagemaker.utils.repack_model")
def test_non_mms_model(repack_model, sagemaker_session):
PyTorchModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version="1.1",
).deploy(1, GPU)

repack_model.assert_not_called()


@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
def test_model_image_accelerator(sagemaker_session):
model = PyTorchModel(