Skip to content

Commit 77f69d0

Browse files
imujjwal96 authored and pengk19 committed
feature: add wait argument to estimator deploy (aws#842)
1 parent a697178 commit 77f69d0

File tree

7 files changed

+33
-17
lines changed

7 files changed

+33
-17
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
331331
return estimator
332332

333333
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
334-
use_compiled_model=False, update_endpoint=False, **kwargs):
334+
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
335335
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
336336
337337
More information:
@@ -355,6 +355,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
355355
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
356356
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
357357
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
358+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
358359
359360
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
360361
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
@@ -381,7 +382,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
381382
accelerator_type=accelerator_type,
382383
endpoint_name=endpoint_name,
383384
update_endpoint=update_endpoint,
384-
tags=self.tags)
385+
tags=self.tags,
386+
wait=wait)
385387

386388
@property
387389
def model_data(self):

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,7 @@ def compile(self, target_instance_family, input_shape, output_path, role,
228228
return self
229229

230230
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
231-
update_endpoint=False, tags=None, kms_key=None):
231+
update_endpoint=False, tags=None, kms_key=None, wait=True):
232232
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
233233
234234
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -256,6 +256,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
256256
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
257257
kms_key (str): The ARN of the KMS key that is used to encrypt the data on the
258258
storage volume attached to the instance hosting the endpoint.
259+
wait (bool): Whether the call should wait until the deployment of this model completes (default: True).
259260
260261
Returns:
261262
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
@@ -296,7 +297,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
296297
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
297298
else:
298299
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant],
299-
tags, kms_key)
300+
tags, kms_key, wait)
300301

301302
if self.predictor_cls:
302303
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

src/sagemaker/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ def pipeline_container_def(self, instance_type):
6767

6868
return sagemaker.pipeline_container_def(self.models, instance_type)
6969

70-
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None):
70+
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True):
7171
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
7272
7373
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -86,6 +86,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
8686
endpoint_name (str): The name of the endpoint to create (default: None).
8787
If not specified, a unique endpoint name will be created.
8888
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
89+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
8990
9091
Returns:
9192
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
@@ -101,7 +102,8 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
101102

102103
production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count)
103104
self.endpoint_name = endpoint_name or self.name
104-
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
105+
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags,
106+
wait=wait)
105107
if self.predictor_cls:
106108
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
107109

src/sagemaker/tuner.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -326,7 +326,8 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
326326

327327
return tuner
328328

329-
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, **kwargs):
329+
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, wait=True,
330+
**kwargs):
330331
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
331332
``sagemaker.RealTimePredictor`` object.
332333
@@ -342,6 +343,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
342343
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
343344
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
344345
the name of the training job is used.
346+
wait (bool): Whether the call should wait until the deployment of model completes (default: True).
345347
**kwargs: Other arguments needed for deployment. Please refer to the ``create_model()`` method of
346348
the associated estimator to see what other arguments are needed.
347349
@@ -354,7 +356,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
354356
sagemaker_session=self.estimator.sagemaker_session)
355357
return best_estimator.deploy(initial_instance_count, instance_type,
356358
accelerator_type=accelerator_type,
357-
endpoint_name=endpoint_name, **kwargs)
359+
endpoint_name=endpoint_name, wait=wait, **kwargs)
358360

359361
def stop_tuning_job(self):
360362
"""Stop latest running hyperparameter tuning job.

tests/unit/test_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -922,7 +922,8 @@ def test_fit_deploy_keep_tags(sagemaker_session):
922922
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
923923
variant,
924924
tags,
925-
None)
925+
None,
926+
True)
926927

927928
sagemaker_session.create_model.assert_called_with(
928929
ANY,

tests/unit/test_model.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,8 @@ def test_deploy(sagemaker_session, tmpdir):
166166
'InitialInstanceCount': 1,
167167
'VariantName': 'AllTraffic'}],
168168
None,
169-
None)
169+
None,
170+
True)
170171

171172

172173
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -182,7 +183,8 @@ def test_deploy_endpoint_name(sagemaker_session, tmpdir):
182183
'InitialInstanceCount': 55,
183184
'VariantName': 'AllTraffic'}],
184185
None,
185-
None)
186+
None,
187+
True)
186188

187189

188190
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -199,7 +201,8 @@ def test_deploy_tags(sagemaker_session, tmpdir):
199201
'InitialInstanceCount': 1,
200202
'VariantName': 'AllTraffic'}],
201203
tags,
202-
None)
204+
None,
205+
True)
203206

204207

205208
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -217,7 +220,8 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
217220
'VariantName': 'AllTraffic',
218221
'AcceleratorType': ACCELERATOR_TYPE}],
219222
None,
220-
None)
223+
None,
224+
True)
221225

222226

223227
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -235,7 +239,8 @@ def test_deploy_kms_key(tfo, time, sagemaker_session):
235239
'InitialInstanceCount': 1,
236240
'VariantName': 'AllTraffic'}],
237241
None,
238-
key)
242+
key,
243+
True)
239244

240245

241246
@patch('sagemaker.session.Session')

tests/unit/test_pipeline_model.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,8 @@ def test_deploy(tfo, time, sagemaker_session):
102102
'InstanceType': INSTANCE_TYPE,
103103
'InitialInstanceCount': 1,
104104
'VariantName': 'AllTraffic'}],
105-
None)
105+
None,
106+
wait=True)
106107

107108

108109
@patch('tarfile.open')
@@ -119,7 +120,8 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
119120
'InstanceType': INSTANCE_TYPE,
120121
'InitialInstanceCount': 1,
121122
'VariantName': 'AllTraffic'}],
122-
None)
123+
None,
124+
wait=True)
123125

124126

125127
@patch('tarfile.open')
@@ -178,7 +180,8 @@ def test_deploy_tags(tfo, time, sagemaker_session):
178180
'InstanceType': INSTANCE_TYPE,
179181
'InitialInstanceCount': 1,
180182
'VariantName': 'AllTraffic'}],
181-
tags)
183+
tags,
184+
wait=True)
182185

183186

184187
def test_delete_model_without_deploy(sagemaker_session):

0 commit comments

Comments
 (0)