Skip to content

Commit f0c10a5

Browse files
committed
breaking: create new inference resources during estimator.deploy() or estimator.transformer()
1 parent 5c6deaf commit f0c10a5

File tree

18 files changed

+247
-111
lines changed

18 files changed

+247
-111
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from sagemaker import utils
1819
from sagemaker.estimator import Framework
1920
from sagemaker.fw_utils import (
2021
framework_name_from_image,
@@ -210,7 +211,8 @@ def create_model(
210211
kwargs["image"] = self.image_name
211212

212213
if "name" not in kwargs:
213-
kwargs["name"] = self._current_job_name
214+
self._ensure_base_job_name()
215+
kwargs["name"] = utils.name_from_base(self.base_job_name)
214216

215217
return ChainerModel(
216218
self.model_data,

src/sagemaker/estimator.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(
140140
training output (default: None).
141141
base_job_name (str): Prefix for training job name when the
142142
:meth:`~sagemaker.estimator.EstimatorBase.fit` method launches.
143-
If not specified, the estimator generates a default job name,
143+
If not specified, the estimator generates a default job name
144144
based on the training image name and current timestamp.
145145
sagemaker_session (sagemaker.session.Session): Session object which
146146
manages interactions with Amazon SageMaker APIs and any other
@@ -328,6 +328,15 @@ def prepare_workflow_for_training(self, job_name=None):
328328
"""
329329
self._prepare_for_training(job_name=job_name)
330330

331+
def _ensure_base_job_name(self):
332+
# honor supplied base_job_name or generate it
333+
if self.base_job_name:
334+
return
335+
elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
336+
self.base_job_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member
337+
else:
338+
self.base_job_name = base_name_from_image(self.train_image())
339+
331340
def _prepare_for_training(self, job_name=None):
332341
"""Set any values in the estimator that need to be set before training.
333342
@@ -339,15 +348,8 @@ def _prepare_for_training(self, job_name=None):
339348
if job_name is not None:
340349
self._current_job_name = job_name
341350
else:
342-
# honor supplied base_job_name or generate it
343-
if self.base_job_name:
344-
base_name = self.base_job_name
345-
elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
346-
base_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member
347-
else:
348-
base_name = base_name_from_image(self.train_image())
349-
350-
self._current_job_name = name_from_base(base_name)
351+
self._ensure_base_job_name()
352+
self._current_job_name = name_from_base(self.base_job_name)
351353

352354
# if output_path was specified we use it otherwise initialize here.
353355
# For Local Mode with local_code=True we don't need an explicit output_path
@@ -483,7 +485,7 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
483485
compatibility, boolean values are also accepted and converted to strings.
484486
Only meaningful when wait is True.
485487
job_name (str): Training job name. If not specified, the estimator generates
486-
a default job name, based on the training image name and current timestamp.
488+
a default job name based on the training image name and current timestamp.
487489
experiment_config (dict[str, str]): Experiment management configuration.
488490
Dictionary contains three optional keys,
489491
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
@@ -667,7 +669,8 @@ def deploy(
667669
wait (bool): Whether the call should wait until the deployment of
668670
model completes (default: True).
669671
model_name (str): Name to use for creating an Amazon SageMaker
670-
model. If not specified, the name of the training job is used.
672+
model. If not specified, the estimator generates a default job name
673+
based on the training image name and current timestamp.
671674
kms_key (str): The ARN of the KMS key that is used to encrypt the
672675
data on the storage volume attached to the instance hosting the
673676
endpoint.
@@ -691,8 +694,11 @@ def deploy(
691694
endpoint and obtain inferences.
692695
"""
693696
self._ensure_latest_training_job()
694-
endpoint_name = endpoint_name or self.latest_training_job.name
695-
model_name = model_name or self.latest_training_job.name
697+
self._ensure_base_job_name()
698+
default_name = name_from_base(self.base_job_name)
699+
endpoint_name = endpoint_name or default_name
700+
model_name = model_name or default_name
701+
696702
self.deploy_instance_type = instance_type
697703
if use_compiled_model:
698704
family = "_".join(instance_type.split(".")[:-1])
@@ -898,18 +904,20 @@ def transformer(
898904
If not specified, this setting is taken from the estimator's
899905
current configuration.
900906
model_name (str): Name to use for creating an Amazon SageMaker
901-
model. If not specified, the name of the training job is used.
907+
model. If not specified, the estimator generates a default job name
908+
based on the training image name and current timestamp.
902909
"""
903910
tags = tags or self.tags
904911

912+
self._ensure_base_job_name()
913+
model_name = model_name or name_from_base(self.base_job_name)
914+
905915
if self.latest_training_job is None:
906916
logging.warning(
907917
"No finished training job found associated with this estimator. Please make sure "
908918
"this estimator is only used for building workflow config"
909919
)
910-
model_name = model_name or self._current_job_name
911920
else:
912-
model_name = model_name or self.latest_training_job.name
913921
if enable_network_isolation is None:
914922
enable_network_isolation = self.enable_network_isolation()
915923

@@ -1993,7 +2001,8 @@ def transformer(
19932001
If not specified, this setting is taken from the estimator's
19942002
current configuration.
19952003
model_name (str): Name to use for creating an Amazon SageMaker
1996-
model. If not specified, the name of the training job is used.
2004+
model. If not specified, the estimator generates a default job name
2005+
based on the training image name and current timestamp.
19972006
19982007
Returns:
19992008
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
@@ -2002,6 +2011,9 @@ def transformer(
20022011
role = role or self.role
20032012
tags = tags or self.tags
20042013

2014+
self._ensure_base_job_name()
2015+
model_name = model_name or name_from_base(self.base_job_name)
2016+
20052017
if self.latest_training_job is not None:
20062018
if enable_network_isolation is None:
20072019
enable_network_isolation = self.enable_network_isolation()
@@ -2017,7 +2029,6 @@ def transformer(
20172029
)
20182030
model._create_sagemaker_model(instance_type, tags=tags)
20192031

2020-
model_name = model.name
20212032
transform_env = model.env.copy()
20222033
if env is not None:
20232034
transform_env.update(env)
@@ -2026,7 +2037,6 @@ def transformer(
20262037
"No finished training job found associated with this estimator. Please make sure "
20272038
"this estimator is only used for building workflow config"
20282039
)
2029-
model_name = model_name or self._current_job_name
20302040
transform_env = env or {}
20312041

20322042
return Transformer(

src/sagemaker/mxnet/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from sagemaker import utils
1819
from sagemaker.estimator import Framework
1920
from sagemaker.fw_utils import (
2021
framework_name_from_image,
@@ -219,7 +220,8 @@ def create_model(
219220
kwargs["image"] = image_name or self.image_name
220221

221222
if "name" not in kwargs:
222-
kwargs["name"] = self._current_job_name
223+
self._ensure_base_job_name()
224+
kwargs["name"] = utils.name_from_base(self.base_job_name)
223225

224226
return MXNetModel(
225227
self.model_data,

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from sagemaker import utils
1819
from sagemaker.estimator import Framework
1920
from sagemaker.fw_utils import (
2021
framework_name_from_image,
@@ -171,7 +172,8 @@ def create_model(
171172
kwargs["image"] = self.image_name
172173

173174
if "name" not in kwargs:
174-
kwargs["name"] = self._current_job_name
175+
self._ensure_base_job_name()
176+
kwargs["name"] = utils.name_from_base(self.base_job_name)
175177

176178
return PyTorchModel(
177179
self.model_data,

src/sagemaker/rl/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import logging
1818
import re
1919

20+
from sagemaker import fw_utils, utils
2021
from sagemaker.estimator import Framework
21-
import sagemaker.fw_utils as fw_utils
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
2424
from sagemaker.tensorflow.model import TensorFlowModel
@@ -218,11 +218,13 @@ def create_model(
218218
Raises:
219219
ValueError: If image_name is not specified and framework enum is not valid.
220220
"""
221+
self._ensure_base_job_name()
222+
221223
base_args = dict(
222224
model_data=self.model_data,
223225
role=role or self.role,
224226
image=kwargs.get("image", self.image_name),
225-
name=kwargs.get("name", self._current_job_name),
227+
name=kwargs.get("name", utils.name_from_base(self.base_job_name)),
226228
container_log_level=self.container_log_level,
227229
sagemaker_session=self.sagemaker_session,
228230
vpc_config=self.get_vpc_config(vpc_config_override),

src/sagemaker/sklearn/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from sagemaker import utils
1819
from sagemaker.estimator import Framework
1920
from sagemaker.fw_registry import default_framework_uri
2021
from sagemaker.fw_utils import (
@@ -194,7 +195,8 @@ def create_model(
194195
kwargs["enable_network_isolation"] = self.enable_network_isolation()
195196

196197
if "name" not in kwargs:
197-
kwargs["name"] = self._current_job_name
198+
self._ensure_base_job_name()
199+
kwargs["name"] = utils.name_from_base(self.base_job_name)
198200

199201
return SKLearnModel(
200202
self.model_data,

src/sagemaker/tensorflow/estimator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ def create_model(
274274
kwargs["image"] = self.image_name
275275

276276
if "name" not in kwargs:
277-
kwargs["name"] = self._current_job_name
277+
self._ensure_base_job_name()
278+
kwargs["name"] = utils.name_from_base(self.base_job_name)
278279

279280
if "enable_network_isolation" not in kwargs:
280281
kwargs["enable_network_isolation"] = self.enable_network_isolation()
@@ -441,17 +442,21 @@ def transformer(
441442
If not specified, this setting is taken from the estimator's
442443
current configuration.
443444
model_name (str): Name to use for creating an Amazon SageMaker
444-
model. If not specified, the name of the training job is used.
445+
model. If not specified, the estimator generates a default job name
446+
based on the training image name and current timestamp.
445447
"""
446448
role = role or self.role
447449

450+
self._ensure_base_job_name()
451+
model_name = model_name or utils.name_from_base(self.base_job_name)
452+
448453
if self.latest_training_job is None:
449454
logging.warning(
450455
"No finished training job found associated with this estimator. Please make sure "
451456
"this estimator is only used for building workflow config"
452457
)
453458
return Transformer(
454-
model_name or self._current_job_name,
459+
model_name,
455460
instance_count,
456461
instance_type,
457462
strategy=strategy,

src/sagemaker/xgboost/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717

18+
from sagemaker import utils
1819
from sagemaker.estimator import Framework, _TrainingJob
1920
from sagemaker.fw_registry import default_framework_uri
2021
from sagemaker.fw_utils import (
@@ -169,7 +170,8 @@ def create_model(
169170
kwargs["image"] = self.image_name
170171

171172
if "name" not in kwargs:
172-
kwargs["name"] = self._current_job_name
173+
self._ensure_base_job_name()
174+
kwargs["name"] = utils.name_from_base(self.base_job_name)
173175

174176
return XGBoostModel(
175177
self.model_data,

tests/integ/test_mxnet_train.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,49 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
7070
assert result is not None
7171

7272

73+
def test_deploy_estimator_with_different_instance_types(
74+
mxnet_training_job,
75+
sagemaker_session,
76+
cpu_instance_type,
77+
alternative_cpu_instance_type,
78+
):
79+
def _deploy_estimator_and_assert_instance_type(estimator, instance_type):
80+
# don't use timeout_and_delete_endpoint_by_name because this tests if
81+
# deploy() creates a new endpoint config/endpoint each time
82+
with timeout(minutes=45):
83+
try:
84+
predictor = estimator.deploy(1, instance_type)
85+
86+
model_name = predictor._model_names[0]
87+
endpoint_name = predictor.endpoint
88+
config_name = sagemaker_session.sagemaker_client.describe_endpoint(
89+
EndpointName=endpoint_name
90+
)["EndpointConfigName"]
91+
config = sagemaker_session.sagemaker_client.describe_endpoint_config(
92+
EndpointConfigName=config_name
93+
)
94+
finally:
95+
predictor.delete_model()
96+
predictor.delete_endpoint()
97+
98+
assert config["ProductionVariants"][0]["InstanceType"] == instance_type
99+
100+
return (model_name, endpoint_name, config_name)
101+
102+
estimator = MXNet.attach(mxnet_training_job, sagemaker_session)
103+
104+
old_model_name, old_endpoint_name, old_config_name = _deploy_estimator_and_assert_instance_type(
105+
estimator, cpu_instance_type
106+
)
107+
new_model_name, new_endpoint_name, new_config_name = _deploy_estimator_and_assert_instance_type(
108+
estimator, alternative_cpu_instance_type
109+
)
110+
111+
assert old_model_name != new_model_name
112+
assert old_endpoint_name != new_endpoint_name
113+
assert old_config_name != new_config_name
114+
115+
73116
def test_deploy_model(
74117
mxnet_training_job,
75118
sagemaker_session,

0 commit comments

Comments
 (0)