Skip to content

feature: Add support for PyTorch 1.2 #1025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ PyTorch SageMaker Estimators

With PyTorch SageMaker Estimators, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.

We recommend that you use the latest supported version, because that's where we focus most of our development efforts.

Expand Down
23 changes: 22 additions & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
"tensorflow-serving-eia": "tensorflow-inference-eia",
"mxnet": "mxnet-training",
"mxnet-serving": "mxnet-inference",
"pytorch": "pytorch-training",
"pytorch-serving": "pytorch-inference",
"mxnet-serving-eia": "mxnet-inference-eia",
}

Expand All @@ -76,6 +78,8 @@
"tensorflow-serving-eia": [1, 14, 0],
"mxnet": [1, 4, 1],
"mxnet-serving": [1, 4, 1],
"pytorch": [1, 2, 0],
"pytorch-serving": [1, 2, 0],
"mxnet-serving-eia": [1, 4, 1],
}

Expand Down Expand Up @@ -119,10 +123,15 @@ def _using_merged_images(region, framework, py_version, framework_version):
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
is_py3 = py_version == "py3" or py_version is None
is_merged_versions = _is_merged_versions(framework, framework_version)

return (
((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION)
and is_merged_versions
and (is_py3 or _is_tf_14_or_later(framework, framework_version))
and (
is_py3
or _is_tf_14_or_later(framework, framework_version)
or _is_pt_12_or_later(framework, framework_version)
)
)


Expand All @@ -140,6 +149,18 @@ def _is_tf_14_or_later(framework, framework_version):
)


def _is_pt_12_or_later(framework, framework_version):
"""
Args:
framework: Name of the frameowork
framework_version: framework version
"""
asimov_lowest_pt = [1, 12, 0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should gov region be handled here? We can do this in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the map above is filled out, it should work fine. not sure why an extra method is needed...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lauren pointed this out, but this should be [1,2,0] and not [1,12,0]

version = [int(s) for s in framework_version.split(".")]
is_pytorch = framework in ("pytorch", "pytorch-serving")
return is_pytorch and version >= asimov_lowest_pt[0 : len(version)]


def _registry_id(region, framework, py_version, account, framework_version):
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ SageMaker PyTorch Estimators and Models

With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker.

Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``.
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``.

We recommend that you use the latest supported version, because that's where we focus most of our development efforts.

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PyTorch(Framework):

__framework_name__ = "pytorch"

LATEST_VERSION = "1.1"
LATEST_VERSION = "1.2.0"
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""

def __init__(
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

import logging
import pkg_resources

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
Expand Down Expand Up @@ -53,6 +54,7 @@ class PyTorchModel(FrameworkModel):
"""

__framework_name__ = "pytorch"
_LOWEST_MMS_VERSION = "1.2"

def __init__(
self,
Expand Down Expand Up @@ -122,19 +124,28 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
dict[str, str]: A container definition object usable with the
CreateModel API.
"""
lowest_mms_version = pkg_resources.parse_version(self._LOWEST_MMS_VERSION)
framework_version = pkg_resources.parse_version(self.framework_version)
is_mms_version = framework_version >= lowest_mms_version

deploy_image = self.image
if not deploy_image:
region_name = self.sagemaker_session.boto_session.region_name

framework_name = self.__framework_name__
if is_mms_version:
framework_name += "-serving"

deploy_image = create_image_uri(
region_name,
self.__framework_name__,
framework_name,
instance_type,
self.framework_version,
self.py_version,
accelerator_type=accelerator_type,
)
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
self._upload_code(deploy_key_prefix, repack=is_mms_version)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())

Expand Down
25 changes: 25 additions & 0 deletions tests/integ/test_pytorch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
assert output.shape == (batch_size, 10)


@pytest.mark.local_mode
def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
pytorch = PyTorch(
entry_point=MNIST_SCRIPT,
role="SageMakerRole",
framework_version=pytorch_full_version,
py_version="py3",
train_instance_count=1,
train_instance_type="local",
sagemaker_session=sagemaker_local_session,
)

pytorch.fit({"training": "file://" + os.path.join(MNIST_DIR, "training")})

predictor = pytorch.deploy(1, "local")
try:
batch_size = 100
data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
output = predictor.predict(data)

assert output.shape == (batch_size, 10)
finally:
predictor.delete_endpoint()


def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())

Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,32 @@ def test_create_image_uri_merged_gov_regions():
)


def test_create_image_uri_merged_pytorch():

image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.12", "py2")
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.12-gpu-py2"

image_uri = fw_utils.create_image_uri("us-west-2", "pytorch", "ml.p3.2xlarge", "1.11", "py2")
assert (
image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:1.11-gpu-py2"
)

image_uri = fw_utils.create_image_uri(
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.12", "py2"
)
assert (
image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12-cpu-py2"
)

image_uri = fw_utils.create_image_uri(
"us-west-2", "pytorch-serving", "ml.c4.2xlarge", "1.11", "py2"
)
assert (
image_uri
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch-serving:1.11-cpu-py2"
)


def test_create_image_uri_accelerator_tf():
image_uri = fw_utils.create_image_uri(
MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"
Expand Down
39 changes: 37 additions & 2 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import os
import pytest
import sys
from mock import MagicMock, Mock
from mock import patch
from mock import ANY, MagicMock, Mock, patch

from sagemaker.pytorch import defaults
from sagemaker.pytorch import PyTorch
Expand Down Expand Up @@ -296,6 +295,42 @@ def test_model(sagemaker_session):
assert isinstance(predictor, PyTorchPredictor)


@patch("sagemaker.utils.create_tar_file", MagicMock())
@patch("sagemaker.utils.repack_model")
def test_mms_model(repack_model, sagemaker_session):
PyTorchModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version="1.2",
).deploy(1, GPU)

repack_model.assert_called_with(
dependencies=[],
inference_script=SCRIPT_PATH,
kms_key=None,
model_uri="s3://some/data.tar.gz",
repacked_model_uri=ANY,
sagemaker_session=sagemaker_session,
source_directory=None,
)


@patch("sagemaker.utils.create_tar_file", MagicMock())
@patch("sagemaker.utils.repack_model")
def test_non_mms_model(repack_model, sagemaker_session):
PyTorchModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version="1.1",
).deploy(1, GPU)

repack_model.assert_not_called()


@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
def test_model_image_accelerator(sagemaker_session):
model = PyTorchModel(
Expand Down