Skip to content

Commit d388519

Browse files
authored
breaking: create new inference resources during estimator.deploy() or estimator.transformer() (#1639)
1 parent cde5500 commit d388519

File tree

20 files changed

+256
-130
lines changed

20 files changed

+256
-130
lines changed

src/sagemaker/algorithm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,11 @@ def _is_marketplace(self):
391391
"""Placeholder docstring"""
392392
return "ProductId" in self.algorithm_spec
393393

394+
def _ensure_base_job_name(self):
395+
"""Set ``self.base_job_name`` if it is not set already."""
396+
if self.base_job_name is None:
397+
self.base_job_name = self.algorithm_arn.split("/")[-1]
398+
394399
def _prepare_for_training(self, job_name=None):
395400
# Validate hyperparameters
396401
# an explicit call to set_hyperparameters() will also validate the hyperparameters

src/sagemaker/chainer/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,11 @@ def create_model(
206206
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
207207
object. See :func:`~sagemaker.chainer.model.ChainerModel` for full details.
208208
"""
209+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
210+
209211
if "image" not in kwargs:
210212
kwargs["image"] = self.image_name
211213

212-
if "name" not in kwargs:
213-
kwargs["name"] = self._current_job_name
214-
215214
return ChainerModel(
216215
self.model_data,
217216
role or self.role,

src/sagemaker/estimator.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(
140140
training output (default: None).
141141
base_job_name (str): Prefix for training job name when the
142142
:meth:`~sagemaker.estimator.EstimatorBase.fit` method launches.
143-
If not specified, the estimator generates a default job name,
143+
If not specified, the estimator generates a default job name
144144
based on the training image name and current timestamp.
145145
sagemaker_session (sagemaker.session.Session): Session object which
146146
manages interactions with Amazon SageMaker APIs and any other
@@ -328,6 +328,28 @@ def prepare_workflow_for_training(self, job_name=None):
328328
"""
329329
self._prepare_for_training(job_name=job_name)
330330

331+
def _ensure_base_job_name(self):
332+
"""Set ``self.base_job_name`` if it is not set already."""
333+
# honor supplied base_job_name or generate it
334+
if self.base_job_name is None:
335+
self.base_job_name = base_name_from_image(self.train_image())
336+
337+
def _get_or_create_name(self, name=None):
338+
"""Generate a name based on the base job name or training image if needed.
339+
340+
Args:
341+
name (str): User-supplied name. If not specified, a name is generated from
342+
the base job name or training image.
343+
344+
Returns:
345+
str: Either the user-supplied name or a generated name.
346+
"""
347+
if name:
348+
return name
349+
350+
self._ensure_base_job_name()
351+
return name_from_base(self.base_job_name)
352+
331353
def _prepare_for_training(self, job_name=None):
332354
"""Set any values in the estimator that need to be set before training.
333355
@@ -336,18 +358,7 @@ def _prepare_for_training(self, job_name=None):
336358
specified, one is generated, using the base name given to the
337359
constructor if applicable.
338360
"""
339-
if job_name is not None:
340-
self._current_job_name = job_name
341-
else:
342-
# honor supplied base_job_name or generate it
343-
if self.base_job_name:
344-
base_name = self.base_job_name
345-
elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
346-
base_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member
347-
else:
348-
base_name = base_name_from_image(self.train_image())
349-
350-
self._current_job_name = name_from_base(base_name)
361+
self._current_job_name = self._get_or_create_name(job_name)
351362

352363
# if output_path was specified we use it otherwise initialize here.
353364
# For Local Mode with local_code=True we don't need an explicit output_path
@@ -483,7 +494,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
483494
compatibility, boolean values are also accepted and converted to strings.
484495
Only meaningful when wait is True.
485496
job_name (str): Training job name. If not specified, the estimator generates
486-
a default job name, based on the training image name and current timestamp.
497+
a default job name based on the training image name and current timestamp.
487498
experiment_config (dict[str, str]): Experiment management configuration.
488499
Dictionary contains three optional keys,
489500
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
@@ -667,7 +678,8 @@ def deploy(
667678
wait (bool): Whether the call should wait until the deployment of
668679
model completes (default: True).
669680
model_name (str): Name to use for creating an Amazon SageMaker
670-
model. If not specified, the name of the training job is used.
681+
model. If not specified, the estimator generates a default job name
682+
based on the training image name and current timestamp.
671683
kms_key (str): The ARN of the KMS key that is used to encrypt the
672684
data on the storage volume attached to the instance hosting the
673685
endpoint.
@@ -691,8 +703,11 @@ def deploy(
691703
endpoint and obtain inferences.
692704
"""
693705
self._ensure_latest_training_job()
694-
endpoint_name = endpoint_name or self.latest_training_job.name
695-
model_name = model_name or self.latest_training_job.name
706+
self._ensure_base_job_name()
707+
default_name = name_from_base(self.base_job_name)
708+
endpoint_name = endpoint_name or default_name
709+
model_name = model_name or default_name
710+
696711
self.deploy_instance_type = instance_type
697712
if use_compiled_model:
698713
family = "_".join(instance_type.split(".")[:-1])
@@ -889,18 +904,18 @@ def transformer(
889904
If not specified, this setting is taken from the estimator's
890905
current configuration.
891906
model_name (str): Name to use for creating an Amazon SageMaker
892-
model. If not specified, the name of the training job is used.
907+
model. If not specified, the estimator generates a default job name
908+
based on the training image name and current timestamp.
893909
"""
894910
tags = tags or self.tags
911+
model_name = self._get_or_create_name(model_name)
895912

896913
if self.latest_training_job is None:
897914
logging.warning(
898915
"No finished training job found associated with this estimator. Please make sure "
899916
"this estimator is only used for building workflow config"
900917
)
901-
model_name = model_name or self._current_job_name
902918
else:
903-
model_name = model_name or self.latest_training_job.name
904919
if enable_network_isolation is None:
905920
enable_network_isolation = self.enable_network_isolation()
906921

@@ -1984,14 +1999,16 @@ def transformer(
19841999
If not specified, this setting is taken from the estimator's
19852000
current configuration.
19862001
model_name (str): Name to use for creating an Amazon SageMaker
1987-
model. If not specified, the name of the training job is used.
2002+
model. If not specified, the estimator generates a default job name
2003+
based on the training image name and current timestamp.
19882004
19892005
Returns:
19902006
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
19912007
SageMaker Batch Transform job.
19922008
"""
19932009
role = role or self.role
19942010
tags = tags or self.tags
2011+
model_name = self._get_or_create_name(model_name)
19952012

19962013
if self.latest_training_job is not None:
19972014
if enable_network_isolation is None:
@@ -2008,7 +2025,6 @@ def transformer(
20082025
)
20092026
model._create_sagemaker_model(instance_type, tags=tags)
20102027

2011-
model_name = model.name
20122028
transform_env = model.env.copy()
20132029
if env is not None:
20142030
transform_env.update(env)
@@ -2017,7 +2033,6 @@ def transformer(
20172033
"No finished training job found associated with this estimator. Please make sure "
20182034
"this estimator is only used for building workflow config"
20192035
)
2020-
model_name = model_name or self._current_job_name
20212036
transform_env = env or {}
20222037

20232038
return Transformer(

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ def create_model(
218218
if "image" not in kwargs:
219219
kwargs["image"] = image_name or self.image_name
220220

221-
if "name" not in kwargs:
222-
kwargs["name"] = self._current_job_name
221+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
223222

224223
return MXNetModel(
225224
self.model_data,

src/sagemaker/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def _get_model_names(self):
303303
EndpointConfigName=self._endpoint_config_name
304304
)
305305
production_variants = endpoint_config["ProductionVariants"]
306-
return map(lambda d: d["ModelName"], production_variants)
306+
return [d["ModelName"] for d in production_variants]
307307

308308

309309
class _CsvSerializer(object):

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ def create_model(
170170
if "image" not in kwargs:
171171
kwargs["image"] = self.image_name
172172

173-
if "name" not in kwargs:
174-
kwargs["name"] = self._current_job_name
173+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
175174

176175
return PyTorchModel(
177176
self.model_data,

src/sagemaker/rl/estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import logging
1818
import re
1919

20+
from sagemaker import fw_utils
2021
from sagemaker.estimator import Framework
21-
import sagemaker.fw_utils as fw_utils
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
2424
from sagemaker.tensorflow.model import TensorFlowModel
@@ -222,12 +222,13 @@ def create_model(
222222
model_data=self.model_data,
223223
role=role or self.role,
224224
image=kwargs.get("image", self.image_name),
225-
name=kwargs.get("name", self._current_job_name),
226225
container_log_level=self.container_log_level,
227226
sagemaker_session=self.sagemaker_session,
228227
vpc_config=self.get_vpc_config(vpc_config_override),
229228
)
230229

230+
base_args["name"] = self._get_or_create_name(kwargs.get("name"))
231+
231232
if not entry_point and (source_dir or dependencies):
232233
raise AttributeError("Please provide an `entry_point`.")
233234

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,14 @@ def create_model(
186186
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
187187
"""
188188
role = role or self.role
189+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
189190

190191
if "image" not in kwargs:
191192
kwargs["image"] = self.image_name
192193

193194
if "enable_network_isolation" not in kwargs:
194195
kwargs["enable_network_isolation"] = self.enable_network_isolation()
195196

196-
if "name" not in kwargs:
197-
kwargs["name"] = self._current_job_name
198-
199197
return SKLearnModel(
200198
self.model_data,
201199
role,

src/sagemaker/tensorflow/estimator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,11 @@ def create_model(
269269
sagemaker.tensorflow.model.TensorFlowModel: A ``TensorFlowModel`` object.
270270
See :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
271271
"""
272+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
273+
272274
if "image" not in kwargs:
273275
kwargs["image"] = self.image_name
274276

275-
if "name" not in kwargs:
276-
kwargs["name"] = self._current_job_name
277-
278277
if "enable_network_isolation" not in kwargs:
279278
kwargs["enable_network_isolation"] = self.enable_network_isolation()
280279

@@ -440,17 +439,19 @@ def transformer(
440439
If not specified, this setting is taken from the estimator's
441440
current configuration.
442441
model_name (str): Name to use for creating an Amazon SageMaker
443-
model. If not specified, the name of the training job is used.
442+
model. If not specified, the estimator generates a default job name
443+
based on the training image name and current timestamp.
444444
"""
445445
role = role or self.role
446+
model_name = self._get_or_create_name(model_name)
446447

447448
if self.latest_training_job is None:
448449
logging.warning(
449450
"No finished training job found associated with this estimator. Please make sure "
450451
"this estimator is only used for building workflow config"
451452
)
452453
return Transformer(
453-
model_name or self._current_job_name,
454+
model_name,
454455
instance_count,
455456
instance_type,
456457
strategy=strategy,

src/sagemaker/xgboost/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,11 @@ def create_model(
164164
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
165165
"""
166166
role = role or self.role
167+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
167168

168169
if "image" not in kwargs:
169170
kwargs["image"] = self.image_name
170171

171-
if "name" not in kwargs:
172-
kwargs["name"] = self._current_job_name
173-
174172
return XGBoostModel(
175173
self.model_data,
176174
role,

tests/integ/test_mxnet_train.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,46 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
7070
assert result is not None
7171

7272

73+
def test_deploy_estimator_with_different_instance_types(
74+
mxnet_training_job, sagemaker_session, cpu_instance_type, alternative_cpu_instance_type,
75+
):
76+
def _deploy_estimator_and_assert_instance_type(estimator, instance_type):
77+
# don't use timeout_and_delete_endpoint_by_name because this tests if
78+
# deploy() creates a new endpoint config/endpoint each time
79+
with timeout(minutes=45):
80+
try:
81+
predictor = estimator.deploy(1, instance_type)
82+
83+
model_name = predictor._model_names[0]
84+
config_name = sagemaker_session.sagemaker_client.describe_endpoint(
85+
EndpointName=predictor.endpoint_name
86+
)["EndpointConfigName"]
87+
config = sagemaker_session.sagemaker_client.describe_endpoint_config(
88+
EndpointConfigName=config_name
89+
)
90+
finally:
91+
predictor.delete_model()
92+
predictor.delete_endpoint()
93+
94+
assert config["ProductionVariants"][0]["InstanceType"] == instance_type
95+
96+
return (model_name, predictor.endpoint_name, config_name)
97+
98+
estimator = MXNet.attach(mxnet_training_job, sagemaker_session)
99+
estimator.base_job_name = "test-mxnet-deploy-twice"
100+
101+
old_model_name, old_endpoint_name, old_config_name = _deploy_estimator_and_assert_instance_type(
102+
estimator, cpu_instance_type
103+
)
104+
new_model_name, new_endpoint_name, new_config_name = _deploy_estimator_and_assert_instance_type(
105+
estimator, alternative_cpu_instance_type
106+
)
107+
108+
assert old_model_name != new_model_name
109+
assert old_endpoint_name != new_endpoint_name
110+
assert old_config_name != new_config_name
111+
112+
73113
def test_deploy_model(
74114
mxnet_training_job,
75115
sagemaker_session,

0 commit comments

Comments
 (0)