Skip to content

fix: add default framework version warning message in Model classes #1218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/sagemaker/chainer/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
"""Default Chainer version for when the framework version is not specified.
This is no longer updated so as to not break existing workflows.
"""

LATEST_VERSION = "5.0.0"
"""The latest version of Chainer included in the SageMaker pre-built Docker images."""
5 changes: 2 additions & 3 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
empty_framework_version_warning,
python_deprecation_warning,
)
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION
from sagemaker.chainer.model import ChainerModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

Expand All @@ -40,8 +40,7 @@ class Chainer(Framework):
_process_slots_per_host = "sagemaker_process_slots_per_host"
_additional_mpi_options = "sagemaker_additional_mpi_options"

LATEST_VERSION = "5.0.0"
"""The latest version of Chainer included in the SageMaker pre-built Docker images."""
LATEST_VERSION = LATEST_VERSION

def __init__(
self,
Expand Down
15 changes: 11 additions & 4 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer.defaults import CHAINER_VERSION
from sagemaker.chainer.defaults import CHAINER_VERSION, LATEST_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -61,7 +66,7 @@ def __init__(
entry_point,
image=None,
py_version="py3",
framework_version=CHAINER_VERSION,
framework_version=None,
predictor_cls=ChainerPredictor,
model_server_workers=None,
**kwargs
Expand Down Expand Up @@ -107,9 +112,11 @@ def __init__(
)
if py_version == "py2":
logger.warning(python_deprecation_warning(self.__framework_name__))
if framework_version is None:
logger.warning(empty_framework_version_warning(CHAINER_VERSION, LATEST_VERSION))

self.py_version = py_version
self.framework_version = framework_version
self.framework_version = framework_version or CHAINER_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/mxnet/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
"""Default MXNet version for when the framework version is not specified.
This is no longer updated so as to not break existing workflows.
"""

LATEST_VERSION = "1.6.0"
"""The latest version of MXNet included in the SageMaker pre-built Docker images."""
5 changes: 2 additions & 3 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
python_deprecation_warning,
is_version_equal_or_higher,
)
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION
from sagemaker.mxnet.model import MXNetModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

Expand All @@ -36,8 +36,7 @@ class MXNet(Framework):
__framework_name__ = "mxnet"
_LOWEST_SCRIPT_MODE_VERSION = ["1", "3"]

LATEST_VERSION = "1.6.0"
"""The latest version of MXNet included in the SageMaker pre-built Docker images."""
LATEST_VERSION = LATEST_VERSION

def __init__(
self,
Expand Down
15 changes: 11 additions & 4 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
from pkg_resources import parse_version

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.mxnet.defaults import MXNET_VERSION, LATEST_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

logger = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -62,7 +67,7 @@ def __init__(
entry_point,
image=None,
py_version="py2",
framework_version=MXNET_VERSION,
framework_version=None,
predictor_cls=MXNetPredictor,
model_server_workers=None,
**kwargs
Expand Down Expand Up @@ -109,9 +114,11 @@ def __init__(

if py_version == "py2":
logger.warning(python_deprecation_warning(self.__framework_name__))
if framework_version is None:
logger.warning(empty_framework_version_warning(MXNET_VERSION, LATEST_VERSION))

self.py_version = py_version
self.framework_version = framework_version
self.framework_version = framework_version or MXNET_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/pytorch/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@
break existing workflows.
"""

LATEST_VERSION = "1.3.1"
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""

PYTHON_VERSION = "py3"
5 changes: 2 additions & 3 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
python_deprecation_warning,
is_version_equal_or_higher,
)
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

Expand All @@ -35,8 +35,7 @@ class PyTorch(Framework):

__framework_name__ = "pytorch"

LATEST_VERSION = "1.3.1"
"""The latest version of PyTorch included in the SageMaker pre-built Docker images."""
LATEST_VERSION = LATEST_VERSION

def __init__(
self,
Expand Down
15 changes: 11 additions & 4 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@
import pkg_resources

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer

logger = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -63,7 +68,7 @@ def __init__(
entry_point,
image=None,
py_version=PYTHON_VERSION,
framework_version=PYTORCH_VERSION,
framework_version=None,
predictor_cls=PyTorchPredictor,
model_server_workers=None,
**kwargs
Expand Down Expand Up @@ -110,9 +115,11 @@ def __init__(

if py_version == "py2":
logger.warning(python_deprecation_warning(self.__framework_name__))
if framework_version is None:
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION))

self.py_version = py_version
self.framework_version = framework_version
self.framework_version = framework_version or PYTORCH_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/tensorflow/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
"""Default TF version for when the framework version is not specified.
This is no longer updated so as to not break existing workflows.
"""

LATEST_VERSION = "2.0.0"
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
5 changes: 2 additions & 3 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.tensorflow.serving import Model
from sagemaker.transformer import Transformer
Expand Down Expand Up @@ -197,8 +197,7 @@ class TensorFlow(Framework):

__framework_name__ = "tensorflow"

LATEST_VERSION = "2.0.0"
"""The latest version of TensorFlow included in the SageMaker pre-built Docker images."""
LATEST_VERSION = LATEST_VERSION

_LATEST_1X_VERSION = "1.15.0"

Expand Down
15 changes: 11 additions & 4 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
import logging

import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.fw_utils import (
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.defaults import TF_VERSION, LATEST_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer

logger = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -60,7 +65,7 @@ def __init__(
entry_point,
image=None,
py_version="py2",
framework_version=TF_VERSION,
framework_version=None,
predictor_cls=TensorFlowPredictor,
model_server_workers=None,
**kwargs
Expand Down Expand Up @@ -107,9 +112,11 @@ def __init__(

if py_version == "py2":
logger.warning(python_deprecation_warning(self.__framework_name__))
if framework_version is None:
logger.warning(empty_framework_version_warning(TF_VERSION, LATEST_VERSION))

self.py_version = py_version
self.framework_version = framework_version
self.framework_version = framework_version or TF_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,16 @@ def test_empty_framework_version(warning, sagemaker_session):

assert estimator.framework_version == defaults.CHAINER_VERSION
warning.assert_called_with(defaults.CHAINER_VERSION, Chainer.LATEST_VERSION)


@patch("sagemaker.chainer.model.empty_framework_version_warning")
def test_model_empty_framework_version(warning, sagemaker_session):
model = ChainerModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=None,
)
assert model.framework_version == defaults.CHAINER_VERSION
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
13 changes: 13 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,19 @@ def test_empty_framework_version(warning, sagemaker_session):
warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION)


@patch("sagemaker.mxnet.model.empty_framework_version_warning")
def test_model_empty_framework_version(warning, sagemaker_session):
model = MXNetModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=None,
)
assert model.framework_version == defaults.MXNET_VERSION
warning.assert_called_with(defaults.MXNET_VERSION, defaults.LATEST_VERSION)


def test_create_model_with_custom_hosting_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,20 @@ def test_empty_framework_version(warning, sagemaker_session):
warning.assert_called_with(defaults.PYTORCH_VERSION, estimator.LATEST_VERSION)


@patch("sagemaker.pytorch.model.empty_framework_version_warning")
def test_model_empty_framework_version(warning, sagemaker_session):
model = PyTorchModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=None,
)

assert model.framework_version == defaults.PYTORCH_VERSION
warning.assert_called_with(defaults.PYTORCH_VERSION, defaults.LATEST_VERSION)


def test_pt_enable_sm_metrics(sagemaker_session):
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=True)
assert pytorch.enable_sagemaker_metrics
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,17 @@ def test_empty_framework_version(warning, sagemaker_session):
assert estimator.framework_version == defaults.TF_VERSION
warning.assert_called_with(defaults.TF_VERSION, estimator.LATEST_VERSION)

model = TensorFlowModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
sagemaker_session=sagemaker_session,
framework_version=None,
)

assert model.framework_version == defaults.TF_VERSION
warning.assert_called_with(defaults.TF_VERSION, defaults.LATEST_VERSION)


def _deprecated_args_msg(args):
return "{} are deprecated in script mode. Please do not set {}.".format(
Expand Down