Skip to content

Commit 8a400b5

Browse files
committed
address PR comments
1 parent 7a52481 commit 8a400b5

File tree

8 files changed

+35
-45
lines changed

8 files changed

+35
-45
lines changed

src/sagemaker/algorithm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,11 @@ def _is_marketplace(self):
391391
"""Placeholder docstring"""
392392
return "ProductId" in self.algorithm_spec
393393

394+
def _ensure_base_job_name(self):
395+
"""Set ``self.base_job_name`` if it is not set already."""
396+
if self.base_job_name is None:
397+
self.base_job_name = self.algorithm_arn.split("/")[-1]
398+
394399
def _prepare_for_training(self, job_name=None):
395400
# Validate hyperparameters
396401
# an explicit call to set_hyperparameters() will also validate the hyperparameters

src/sagemaker/estimator.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,25 @@ def prepare_workflow_for_training(self, job_name=None):
331331
def _ensure_base_job_name(self):
332332
"""Set ``self.base_job_name`` if it is not set already."""
333333
# honor supplied base_job_name or generate it
334-
if self.base_job_name:
335-
return
336-
337-
if isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
338-
self.base_job_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member
339-
else:
334+
if self.base_job_name is None:
340335
self.base_job_name = base_name_from_image(self.train_image())
341336

337+
def _get_or_create_name(self, name=None):
338+
"""Generate a name based on the base job name or training image if needed.
339+
340+
Args:
341+
name (str): User-supplied name. If not specified, a name is generated from
342+
the base job name or training image.
343+
344+
Returns:
345+
str: Either the user-supplied name or a generated name.
346+
"""
347+
if name:
348+
return name
349+
350+
self._ensure_base_job_name()
351+
return name_from_base(self.base_job_name)
352+
342353
def _prepare_for_training(self, job_name=None):
343354
"""Set any values in the estimator that need to be set before training.
344355
@@ -347,11 +358,7 @@ def _prepare_for_training(self, job_name=None):
347358
specified, one is generated, using the base name given to the
348359
constructor if applicable.
349360
"""
350-
if job_name is not None:
351-
self._current_job_name = job_name
352-
else:
353-
self._ensure_base_job_name()
354-
self._current_job_name = name_from_base(self.base_job_name)
361+
self._current_job_name = self._get_or_create_name(job_name)
355362

356363
# if output_path was specified we use it otherwise initialize here.
357364
# For Local Mode with local_code=True we don't need an explicit output_path
@@ -910,9 +917,7 @@ def transformer(
910917
based on the training image name and current timestamp.
911918
"""
912919
tags = tags or self.tags
913-
914-
self._ensure_base_job_name()
915-
model_name = model_name or name_from_base(self.base_job_name)
920+
model_name = self._get_or_create_name(model_name)
916921

917922
if self.latest_training_job is None:
918923
logging.warning(
@@ -2012,9 +2017,7 @@ def transformer(
20122017
"""
20132018
role = role or self.role
20142019
tags = tags or self.tags
2015-
2016-
self._ensure_base_job_name()
2017-
model_name = model_name or name_from_base(self.base_job_name)
2020+
model_name = self._get_or_create_name(model_name)
20182021

20192022
if self.latest_training_job is not None:
20202023
if enable_network_isolation is None:

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import utils
1918
from sagemaker.estimator import Framework
2019
from sagemaker.fw_utils import (
2120
framework_name_from_image,
@@ -219,9 +218,7 @@ def create_model(
219218
if "image" not in kwargs:
220219
kwargs["image"] = image_name or self.image_name
221220

222-
if "name" not in kwargs:
223-
self._ensure_base_job_name()
224-
kwargs["name"] = utils.name_from_base(self.base_job_name)
221+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
225222

226223
return MXNetModel(
227224
self.model_data,

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import utils
1918
from sagemaker.estimator import Framework
2019
from sagemaker.fw_utils import (
2120
framework_name_from_image,
@@ -171,9 +170,7 @@ def create_model(
171170
if "image" not in kwargs:
172171
kwargs["image"] = self.image_name
173172

174-
if "name" not in kwargs:
175-
self._ensure_base_job_name()
176-
kwargs["name"] = utils.name_from_base(self.base_job_name)
173+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
177174

178175
return PyTorchModel(
179176
self.model_data,

src/sagemaker/rl/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import re
1919

20-
from sagemaker import fw_utils, utils
20+
from sagemaker import fw_utils
2121
from sagemaker.estimator import Framework
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
@@ -218,18 +218,18 @@ def create_model(
218218
Raises:
219219
ValueError: If image_name is not specified and framework enum is not valid.
220220
"""
221-
self._ensure_base_job_name()
222221

223222
base_args = dict(
224223
model_data=self.model_data,
225224
role=role or self.role,
226225
image=kwargs.get("image", self.image_name),
227-
name=kwargs.get("name", utils.name_from_base(self.base_job_name)),
228226
container_log_level=self.container_log_level,
229227
sagemaker_session=self.sagemaker_session,
230228
vpc_config=self.get_vpc_config(vpc_config_override),
231229
)
232230

231+
base_args["name"] = self._get_or_create_name(kwargs.get("name"))
232+
233233
if not entry_point and (source_dir or dependencies):
234234
raise AttributeError("Please provide an `entry_point`.")
235235

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import utils
1918
from sagemaker.estimator import Framework
2019
from sagemaker.fw_registry import default_framework_uri
2120
from sagemaker.fw_utils import (
@@ -187,17 +186,14 @@ def create_model(
187186
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
188187
"""
189188
role = role or self.role
189+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
190190

191191
if "image" not in kwargs:
192192
kwargs["image"] = self.image_name
193193

194194
if "enable_network_isolation" not in kwargs:
195195
kwargs["enable_network_isolation"] = self.enable_network_isolation()
196196

197-
if "name" not in kwargs:
198-
self._ensure_base_job_name()
199-
kwargs["name"] = utils.name_from_base(self.base_job_name)
200-
201197
return SKLearnModel(
202198
self.model_data,
203199
role,

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,11 @@ def create_model(
270270
sagemaker.tensorflow.model.TensorFlowModel: A ``TensorFlowModel`` object.
271271
See :class:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
272272
"""
273+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
274+
273275
if "image" not in kwargs:
274276
kwargs["image"] = self.image_name
275277

276-
if "name" not in kwargs:
277-
self._ensure_base_job_name()
278-
kwargs["name"] = utils.name_from_base(self.base_job_name)
279-
280278
if "enable_network_isolation" not in kwargs:
281279
kwargs["enable_network_isolation"] = self.enable_network_isolation()
282280

@@ -446,9 +444,7 @@ def transformer(
446444
based on the training image name and current timestamp.
447445
"""
448446
role = role or self.role
449-
450-
self._ensure_base_job_name()
451-
model_name = model_name or utils.name_from_base(self.base_job_name)
447+
model_name = self._get_or_create_name(model_name)
452448

453449
if self.latest_training_job is None:
454450
logging.warning(

src/sagemaker/xgboost/estimator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import logging
1717

18-
from sagemaker import utils
1918
from sagemaker.estimator import Framework, _TrainingJob
2019
from sagemaker.fw_registry import default_framework_uri
2120
from sagemaker.fw_utils import (
@@ -165,14 +164,11 @@ def create_model(
165164
See :func:`~sagemaker.xgboost.model.XGBoostModel` for full details.
166165
"""
167166
role = role or self.role
167+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
168168

169169
if "image" not in kwargs:
170170
kwargs["image"] = self.image_name
171171

172-
if "name" not in kwargs:
173-
self._ensure_base_job_name()
174-
kwargs["name"] = utils.name_from_base(self.base_job_name)
175-
176172
return XGBoostModel(
177173
self.model_data,
178174
role,

0 commit comments

Comments
 (0)