Skip to content

Commit c4d0ece

Browse files
committed
feature: add wait argument to estimator deploy
1 parent 003f9c5 commit c4d0ece

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331331
return estimator
332332

333333
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
334-
use_compiled_model=False, update_endpoint=False, **kwargs):
334+
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
335335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336336
337337
More information:
@@ -355,6 +355,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
355355
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
356356
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
357357
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
358+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
358359
359360
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
360361
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -381,7 +382,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
381382
accelerator_type=accelerator_type,
382383
endpoint_name=endpoint_name,
383384
update_endpoint=update_endpoint,
384-
tags=self.tags)
385+
tags=self.tags,
386+
wait=wait)
385387

386388
@property
387389
def model_data(self):

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def compile(self, target_instance_family, input_shape, output_path, role,
228228
return self
229229

230230
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
231-
update_endpoint=False, tags=None, kms_key=None):
231+
update_endpoint=False, tags=None, kms_key=None, wait=True):
232232
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
233233
234234
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -256,6 +256,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
256256
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
257257
kms_key (str): The ARN of the KMS key that is used to encrypt the data on the
258258
storage volume attached to the instance hosting the endpoint.
259+
wait (bool): Whether the call should wait until the deployment of this model completes (default: True).
259260
260261
Returns:
261262
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
@@ -296,7 +297,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
296297
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
297298
else:
298299
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant],
299-
tags, kms_key)
300+
tags, kms_key, wait)
300301

301302
if self.predictor_cls:
302303
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

tests/unit/test_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,8 @@ def test_fit_deploy_keep_tags(sagemaker_session):
922922
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
923923
variant,
924924
tags,
925-
None)
925+
None,
926+
True)
926927

927928
sagemaker_session.create_model.assert_called_with(
928929
ANY,

tests/unit/test_model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ def test_deploy(sagemaker_session, tmpdir):
166166
'InitialInstanceCount': 1,
167167
'VariantName': 'AllTraffic'}],
168168
None,
169-
None)
169+
None,
170+
True)
170171

171172

172173
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -182,7 +183,8 @@ def test_deploy_endpoint_name(sagemaker_session, tmpdir):
182183
'InitialInstanceCount': 55,
183184
'VariantName': 'AllTraffic'}],
184185
None,
185-
None)
186+
None,
187+
True)
186188

187189

188190
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -199,7 +201,8 @@ def test_deploy_tags(sagemaker_session, tmpdir):
199201
'InitialInstanceCount': 1,
200202
'VariantName': 'AllTraffic'}],
201203
tags,
202-
None)
204+
None,
205+
True)
203206

204207

205208
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -217,7 +220,8 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
217220
'VariantName': 'AllTraffic',
218221
'AcceleratorType': ACCELERATOR_TYPE}],
219222
None,
220-
None)
223+
None,
224+
True)
221225

222226

223227
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -235,7 +239,8 @@ def test_deploy_kms_key(tfo, time, sagemaker_session):
235239
'InitialInstanceCount': 1,
236240
'VariantName': 'AllTraffic'}],
237241
None,
238-
key)
242+
key,
243+
True)
239244

240245

241246
@patch('sagemaker.session.Session')

0 commit comments

Comments
 (0)