Skip to content

breaking: create new inference resources during estimator.deploy() or estimator.transformer() #1639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 30, 2020
5 changes: 5 additions & 0 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,11 @@ def _is_marketplace(self):
"""Placeholder docstring"""
return "ProductId" in self.algorithm_spec

def _ensure_base_job_name(self):
"""Set ``self.base_job_name`` if it is not set already."""
if self.base_job_name is None:
self.base_job_name = self.algorithm_arn.split("/")[-1]

def _prepare_for_training(self, job_name=None):
# Validate hyperparameters
# an explicit call to set_hyperparameters() will also validate the hyperparameters
Expand Down
5 changes: 2 additions & 3 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,11 @@ def create_model(
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
"""
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return ChainerModel(
self.model_data,
role or self.role,
Expand Down
61 changes: 38 additions & 23 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
training output (default: None).
base_job_name (str): Prefix for training job name when the
:meth:`~sagemaker.estimator.EstimatorBase.fit` method launches.
If not specified, the estimator generates a default job name,
If not specified, the estimator generates a default job name
based on the training image name and current timestamp.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
Expand Down Expand Up @@ -328,6 +328,28 @@ def prepare_workflow_for_training(self, job_name=None):
"""
self._prepare_for_training(job_name=job_name)

def _ensure_base_job_name(self):
"""Set ``self.base_job_name`` if it is not set already."""
# honor supplied base_job_name or generate it
if self.base_job_name is None:
self.base_job_name = base_name_from_image(self.train_image())

def _get_or_create_name(self, name=None):
"""Generate a name based on the base job name or training image if needed.

Args:
name (str): User-supplied name. If not specified, a name is generated from
the base job name or training image.

Returns:
str: Either the user-supplied name or a generated name.
"""
if name:
return name

self._ensure_base_job_name()
return name_from_base(self.base_job_name)

def _prepare_for_training(self, job_name=None):
"""Set any values in the estimator that need to be set before training.

Expand All @@ -336,18 +358,7 @@ def _prepare_for_training(self, job_name=None):
specified, one is generated, using the base name given to the
constructor if applicable.
"""
if job_name is not None:
self._current_job_name = job_name
else:
# honor supplied base_job_name or generate it
if self.base_job_name:
base_name = self.base_job_name
elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
base_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member
else:
base_name = base_name_from_image(self.train_image())

self._current_job_name = name_from_base(base_name)
self._current_job_name = self._get_or_create_name(job_name)

# if output_path was specified we use it otherwise initialize here.
# For Local Mode with local_code=True we don't need an explicit output_path
Expand Down Expand Up @@ -483,7 +494,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
compatibility, boolean values are also accepted and converted to strings.
Only meaningful when wait is True.
job_name (str): Training job name. If not specified, the estimator generates
a default job name, based on the training image name and current timestamp.
a default job name based on the training image name and current timestamp.
experiment_config (dict[str, str]): Experiment management configuration.
Dictionary contains three optional keys,
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
Expand Down Expand Up @@ -667,7 +678,8 @@ def deploy(
wait (bool): Whether the call should wait until the deployment of
model completes (default: True).
model_name (str): Name to use for creating an Amazon SageMaker
model. If not specified, the name of the training job is used.
model. If not specified, the estimator generates a default job name
based on the training image name and current timestamp.
kms_key (str): The ARN of the KMS key that is used to encrypt the
data on the storage volume attached to the instance hosting the
endpoint.
Expand All @@ -691,8 +703,11 @@ def deploy(
endpoint and obtain inferences.
"""
self._ensure_latest_training_job()
endpoint_name = endpoint_name or self.latest_training_job.name
model_name = model_name or self.latest_training_job.name
self._ensure_base_job_name()
default_name = name_from_base(self.base_job_name)
endpoint_name = endpoint_name or default_name
model_name = model_name or default_name

self.deploy_instance_type = instance_type
if use_compiled_model:
family = "_".join(instance_type.split(".")[:-1])
Expand Down Expand Up @@ -889,18 +904,18 @@ def transformer(
If not specified, this setting is taken from the estimator's
current configuration.
model_name (str): Name to use for creating an Amazon SageMaker
model. If not specified, the name of the training job is used.
model. If not specified, the estimator generates a default job name
based on the training image name and current timestamp.
"""
tags = tags or self.tags
model_name = self._get_or_create_name(model_name)

if self.latest_training_job is None:
logging.warning(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_name = model_name or self._current_job_name
else:
model_name = model_name or self.latest_training_job.name
if enable_network_isolation is None:
enable_network_isolation = self.enable_network_isolation()

Expand Down Expand Up @@ -1984,14 +1999,16 @@ def transformer(
If not specified, this setting is taken from the estimator's
current configuration.
model_name (str): Name to use for creating an Amazon SageMaker
model. If not specified, the name of the training job is used.
model. If not specified, the estimator generates a default job name
based on the training image name and current timestamp.

Returns:
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
SageMaker Batch Transform job.
"""
role = role or self.role
tags = tags or self.tags
model_name = self._get_or_create_name(model_name)

if self.latest_training_job is not None:
if enable_network_isolation is None:
Expand All @@ -2008,7 +2025,6 @@ def transformer(
)
model._create_sagemaker_model(instance_type, tags=tags)

model_name = model.name
transform_env = model.env.copy()
if env is not None:
transform_env.update(env)
Expand All @@ -2017,7 +2033,6 @@ def transformer(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_name = model_name or self._current_job_name
transform_env = env or {}

return Transformer(
Expand Down
3 changes: 1 addition & 2 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ def create_model(
if "image" not in kwargs:
kwargs["image"] = image_name or self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

return MXNetModel(
self.model_data,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _get_model_names(self):
EndpointConfigName=self._endpoint_config_name
)
production_variants = endpoint_config["ProductionVariants"]
return map(lambda d: d["ModelName"], production_variants)
return [d["ModelName"] for d in production_variants]


class _CsvSerializer(object):
Expand Down
3 changes: 1 addition & 2 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ def create_model(
if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

return PyTorchModel(
self.model_data,
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import logging
import re

from sagemaker import fw_utils
from sagemaker.estimator import Framework
import sagemaker.fw_utils as fw_utils
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
from sagemaker.mxnet.model import MXNetModel
from sagemaker.tensorflow.model import TensorFlowModel
Expand Down Expand Up @@ -222,12 +222,13 @@ def create_model(
model_data=self.model_data,
role=role or self.role,
image=kwargs.get("image", self.image_name),
name=kwargs.get("name", self._current_job_name),
container_log_level=self.container_log_level,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
)

base_args["name"] = self._get_or_create_name(kwargs.get("name"))

if not entry_point and (source_dir or dependencies):
raise AttributeError("Please provide an `entry_point`.")

Expand Down
4 changes: 1 addition & 3 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,14 @@ def create_model(
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
"""
role = role or self.role
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_name

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return SKLearnModel(
self.model_data,
role,
Expand Down
11 changes: 6 additions & 5 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,11 @@ def create_model(
sagemaker.tensorflow.model.TensorFlowModel: A ``TensorFlowModel`` object.
See :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
"""
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

if "enable_network_isolation" not in kwargs:
kwargs["enable_network_isolation"] = self.enable_network_isolation()

Expand Down Expand Up @@ -440,17 +439,19 @@ def transformer(
If not specified, this setting is taken from the estimator's
current configuration.
model_name (str): Name to use for creating an Amazon SageMaker
model. If not specified, the name of the training job is used.
model. If not specified, the estimator generates a default job name
based on the training image name and current timestamp.
"""
role = role or self.role
model_name = self._get_or_create_name(model_name)

if self.latest_training_job is None:
logging.warning(
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
return Transformer(
model_name or self._current_job_name,
model_name,
instance_count,
instance_type,
strategy=strategy,
Expand Down
4 changes: 1 addition & 3 deletions src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,11 @@ def create_model(
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
"""
role = role or self.role
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_name

if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return XGBoostModel(
self.model_data,
role,
Expand Down
40 changes: 40 additions & 0 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,46 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
assert result is not None


def test_deploy_estimator_with_different_instance_types(
mxnet_training_job, sagemaker_session, cpu_instance_type, alternative_cpu_instance_type,
):
def _deploy_estimator_and_assert_instance_type(estimator, instance_type):
# don't use timeout_and_delete_endpoint_by_name because this tests if
# deploy() creates a new endpoint config/endpoint each time
with timeout(minutes=45):
try:
predictor = estimator.deploy(1, instance_type)

model_name = predictor._model_names[0]
config_name = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=predictor.endpoint_name
)["EndpointConfigName"]
config = sagemaker_session.sagemaker_client.describe_endpoint_config(
EndpointConfigName=config_name
)
finally:
predictor.delete_model()
predictor.delete_endpoint()

assert config["ProductionVariants"][0]["InstanceType"] == instance_type

return (model_name, predictor.endpoint_name, config_name)

estimator = MXNet.attach(mxnet_training_job, sagemaker_session)
estimator.base_job_name = "test-mxnet-deploy-twice"

old_model_name, old_endpoint_name, old_config_name = _deploy_estimator_and_assert_instance_type(
estimator, cpu_instance_type
)
new_model_name, new_endpoint_name, new_config_name = _deploy_estimator_and_assert_instance_type(
estimator, alternative_cpu_instance_type
)

assert old_model_name != new_model_name
assert old_endpoint_name != new_endpoint_name
assert old_config_name != new_config_name


def test_deploy_model(
mxnet_training_job,
sagemaker_session,
Expand Down
Loading