Skip to content

Commit 42b3906

Browse files
committed
feature: allow custom model name during deploy
1 parent 686569e commit 42b3906

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331331
return estimator
332332

333333
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
334-
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
334+
use_compiled_model=False, update_endpoint=False, wait=True, model_name=None, **kwargs):
335335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336336
337337
More information:
@@ -351,11 +351,13 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
351351
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
352352
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
353353
corresponding to the previous EndpointConfig. Default: False
354+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
355+
model_name (str): Name to use for creating an Amazon SageMaker model. If not specified, the name of
356+
the training job is used.
354357
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example:
355358
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
356359
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
357360
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
358-
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
359361
360362
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
361363
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -367,6 +369,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
367369
"""
368370
self._ensure_latest_training_job()
369371
endpoint_name = endpoint_name or self.latest_training_job.name
372+
model_name = model_name or self.latest_training_job.name
370373
self.deploy_instance_type = instance_type
371374
if use_compiled_model:
372375
family = '_'.join(instance_type.split('.')[:-1])
@@ -376,6 +379,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
376379
model = self._compiled_models[family]
377380
else:
378381
model = self.create_model(**kwargs)
382+
model.name = model_name
379383
return model.deploy(
380384
instance_type=instance_type,
381385
initial_instance_count=initial_instance_count,

tests/integ/test_tf_script_mode.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,21 @@ def test_mnist_async(sagemaker_session):
136136
training_job_name = estimator.latest_training_job.name
137137
time.sleep(20)
138138
endpoint_name = training_job_name
139+
model_name = 'model-name-1'
139140
_assert_training_job_tags_match(sagemaker_session.sagemaker_client,
140141
estimator.latest_training_job.name, TAGS)
141142
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
142143
estimator = TensorFlow.attach(training_job_name=training_job_name,
143144
sagemaker_session=sagemaker_session)
144145
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge',
145-
endpoint_name=endpoint_name)
146+
endpoint_name=endpoint_name, model_name=model_name)
146147

147148
result = predictor.predict(np.zeros(784))
148149
print('predict result: {}'.format(result))
149150
_assert_endpoint_tags_match(sagemaker_session.sagemaker_client, predictor.endpoint, TAGS)
150151
_assert_model_tags_match(sagemaker_session.sagemaker_client,
151152
estimator.latest_training_job.name, TAGS)
153+
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
152154

153155

154156
def test_deploy_with_input_handlers(sagemaker_session, instance_type):
@@ -208,3 +210,8 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
208210
training_job_description = sagemaker_client.describe_training_job(
209211
TrainingJobName=training_job_name)
210212
_assert_tags_match(sagemaker_client, training_job_description['TrainingJobArn'], tags)
213+
214+
215+
def _assert_model_name_match(sagemaker_client, endpoint_config_name, model_name):
216+
endpoint_config_description = sagemaker_client.describe_endpoint(EndpointConfigName=endpoint_config_name)
217+
assert model_name == endpoint_config_description['ProductionVariants'][0]['ModelName']

tests/unit/test_estimator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,31 @@ def test_deploy_with_update_endpoint(sagemaker_session):
11401140
sagemaker_session.create_endpoint.assert_not_called()
11411141

11421142

1143+
def test_deploy_with_model_name(sagemaker_session):
1144+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1145+
sagemaker_session=sagemaker_session)
1146+
estimator.set_hyperparameters(**HYPERPARAMS)
1147+
estimator.fit({'train': 's3://bucket/training-prefix'})
1148+
model_name = 'model-name'
1149+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, model_name=model_name)
1150+
1151+
sagemaker_session.create_model.assert_called_once()
1152+
args, kwargs = sagemaker_session.create_model.call_args
1153+
assert args[0] == model_name
1154+
1155+
1156+
def test_deploy_with_no_model_name(sagemaker_session):
1157+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
1158+
sagemaker_session=sagemaker_session)
1159+
estimator.set_hyperparameters(**HYPERPARAMS)
1160+
estimator.fit({'train': 's3://bucket/training-prefix'})
1161+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
1162+
1163+
sagemaker_session.create_model.assert_called_once()
1164+
args, kwargs = sagemaker_session.create_model.call_args
1165+
assert args[0].startswith(IMAGE_NAME)
1166+
1167+
11431168
@patch('sagemaker.estimator.LocalSession')
11441169
@patch('sagemaker.estimator.Session')
11451170
def test_local_mode(session_class, local_session_class):

tests/unit/test_tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def test_deploy_default(tuner):
566566

567567
tuner.estimator.sagemaker_session.create_model.assert_called_once()
568568
args = tuner.estimator.sagemaker_session.create_model.call_args[0]
569-
assert args[0].startswith(IMAGE_NAME)
569+
570+
assert args[0] == 'neo'
570571
assert args[1] == ROLE
571572
assert args[2]['Image'] == IMAGE_NAME
572573
assert args[2]['ModelDataUrl'] == MODEL_DATA

0 commit comments

Comments
 (0)