Skip to content

Commit 3b5e661

Browse files
authored
change: allow specifying model name when creating a Transformer from an Estimator (#1398)
1 parent f76f8a8 commit 3b5e661

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ def transformer(
826826
volume_kms_key=None,
827827
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
828828
enable_network_isolation=None,
829+
model_name=None,
829830
):
830831
"""Return a ``Transformer`` that uses a SageMaker Model based on the
831832
training job. It reuses the SageMaker Session and base job name used by
@@ -876,6 +877,8 @@ def transformer(
876877
user entry script for inference. Also known as Internet-free mode.
877878
If not specified, this setting is taken from the estimator's
878879
current configuration.
880+
model_name (str): Name to use for creating an Amazon SageMaker
881+
model. If not specified, the name of the training job is used.
879882
"""
880883
tags = tags or self.tags
881884

@@ -884,9 +887,9 @@ def transformer(
884887
"No finished training job found associated with this estimator. Please make sure "
885888
"this estimator is only used for building workflow config"
886889
)
887-
model_name = self._current_job_name
890+
model_name = model_name or self._current_job_name
888891
else:
889-
model_name = self.latest_training_job.name
892+
model_name = model_name or self.latest_training_job.name
890893
if enable_network_isolation is None:
891894
enable_network_isolation = self.enable_network_isolation()
892895

@@ -1897,6 +1900,7 @@ def transformer(
18971900
entry_point=None,
18981901
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
18991902
enable_network_isolation=None,
1903+
model_name=None,
19001904
):
19011905
"""Return a ``Transformer`` that uses a SageMaker Model based on the
19021906
training job. It reuses the SageMaker Session and base job name used by
@@ -1953,6 +1957,8 @@ def transformer(
19531957
user entry script for inference. Also known as Internet-free mode.
19541958
If not specified, this setting is taken from the estimator's
19551959
current configuration.
1960+
model_name (str): Name to use for creating an Amazon SageMaker
1961+
model. If not specified, the name of the training job is used.
19561962
19571963
Returns:
19581964
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
@@ -1972,6 +1978,7 @@ def transformer(
19721978
vpc_config_override=vpc_config_override,
19731979
model_kms_key=self.output_kms_key,
19741980
enable_network_isolation=enable_network_isolation,
1981+
name=model_name,
19751982
)
19761983
model._create_sagemaker_model(instance_type, tags=tags)
19771984

@@ -1984,7 +1991,7 @@ def transformer(
19841991
"No finished training job found associated with this estimator. Please make sure "
19851992
"this estimator is only used for building workflow config"
19861993
)
1987-
model_name = self._current_job_name
1994+
model_name = model_name or self._current_job_name
19881995
transform_env = env or {}
19891996

19901997
return Transformer(

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,7 @@ def transformer(
780780
entry_point=None,
781781
vpc_config_override=VPC_CONFIG_DEFAULT,
782782
enable_network_isolation=None,
783+
model_name=None,
783784
):
784785
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It
785786
reuses the SageMaker Session and base job name used by the Estimator.
@@ -837,6 +838,8 @@ def transformer(
837838
user entry script for inference. Also known as Internet-free mode.
838839
If not specified, this setting is taken from the estimator's
839840
current configuration.
841+
model_name (str): Name to use for creating an Amazon SageMaker
842+
model. If not specified, the name of the training job is used.
840843
"""
841844
role = role or self.role
842845

@@ -846,7 +849,7 @@ def transformer(
846849
"this estimator is only used for building workflow config"
847850
)
848851
return Transformer(
849-
self._current_job_name,
852+
model_name or self._current_job_name,
850853
instance_count,
851854
instance_type,
852855
strategy=strategy,
@@ -873,6 +876,7 @@ def transformer(
873876
endpoint_type=endpoint_type,
874877
entry_point=entry_point,
875878
enable_network_isolation=enable_network_isolation,
879+
name=model_name,
876880
)
877881

878882
return model.transformer(

tests/unit/test_estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1374,6 +1374,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13741374
env = {"FOO": "BAR"}
13751375
new_role = "dummy-model-role"
13761376
new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]}
1377+
model_name = "model-name"
13771378

13781379
transformer = fw.transformer(
13791380
INSTANCE_COUNT,
@@ -1392,10 +1393,11 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
13921393
model_server_workers=1,
13931394
vpc_config_override=new_vpc_config,
13941395
enable_network_isolation=True,
1396+
model_name=model_name,
13951397
)
13961398

13971399
sagemaker_session.create_model.assert_called_with(
1398-
MODEL_IMAGE,
1400+
model_name,
13991401
new_role,
14001402
MODEL_CONTAINER_DEF,
14011403
vpc_config=new_vpc_config,
@@ -1413,6 +1415,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
14131415
assert transformer.base_transform_job_name == base_name
14141416
assert transformer.tags == TAGS
14151417
assert transformer.volume_kms_key == kms_key
1418+
assert transformer.model_name == model_name
14161419

14171420

14181421
def test_ensure_latest_training_job(sagemaker_session):
@@ -1492,6 +1495,7 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem
14921495
max_payload = 6
14931496
env = {"FOO": "BAR"}
14941497
new_vpc_config = {"Subnets": ["x"], "SecurityGroupIds": ["y"]}
1498+
model_name = "model-name"
14951499

14961500
transformer = estimator.transformer(
14971501
INSTANCE_COUNT,
@@ -1508,6 +1512,7 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem
15081512
role=ROLE,
15091513
vpc_config_override=new_vpc_config,
15101514
enable_network_isolation=True,
1515+
model_name=model_name,
15111516
)
15121517

15131518
create_model.assert_called_with(
@@ -1524,6 +1529,7 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem
15241529
assert transformer.env == env
15251530
assert transformer.base_transform_job_name == base_name
15261531
assert transformer.tags == TAGS
1532+
assert transformer.model_name == model_name
15271533

15281534

15291535
# _TrainingJob 'utils'

tests/unit/test_tf_estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def test_transformer_creation_with_optional_args(create_model, sagemaker_session
364364
new_role = "role"
365365
model_server_workers = 2
366366
vpc_config = {"Subnets": ["1234"], "SecurityGroupIds": ["5678"]}
367+
model_name = "model-name"
367368

368369
tf.transformer(
369370
INSTANCE_COUNT,
@@ -384,6 +385,7 @@ def test_transformer_creation_with_optional_args(create_model, sagemaker_session
384385
entry_point=SERVING_SCRIPT_FILE,
385386
vpc_config_override=vpc_config,
386387
enable_network_isolation=True,
388+
model_name=model_name,
387389
)
388390

389391
create_model.assert_called_with(
@@ -393,6 +395,7 @@ def test_transformer_creation_with_optional_args(create_model, sagemaker_session
393395
endpoint_type="tensorflow-serving",
394396
entry_point=SERVING_SCRIPT_FILE,
395397
enable_network_isolation=True,
398+
name=model_name,
396399
)
397400
model.transformer.assert_called_with(
398401
INSTANCE_COUNT,
@@ -432,6 +435,7 @@ def test_transformer_creation_without_optional_args(create_model, sagemaker_sess
432435
vpc_config_override="VPC_CONFIG_DEFAULT",
433436
entry_point=None,
434437
enable_network_isolation=False,
438+
name=None,
435439
)
436440
model.transformer.assert_called_with(
437441
INSTANCE_COUNT,

0 commit comments

Comments
 (0)