Skip to content

feature: add wait argument to estimator deploy #842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
return estimator

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
use_compiled_model=False, update_endpoint=False, **kwargs):
use_compiled_model=False, update_endpoint=False, wait=True, **kwargs):
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.

More information:
Expand All @@ -355,6 +355,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
wait (bool): Whether the call should wait until the deployment of the model completes (default: True).

**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
Expand All @@ -381,7 +382,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
accelerator_type=accelerator_type,
endpoint_name=endpoint_name,
update_endpoint=update_endpoint,
tags=self.tags)
tags=self.tags,
wait=wait)

@property
def model_data(self):
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def compile(self, target_instance_family, input_shape, output_path, role,
return self

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
update_endpoint=False, tags=None, kms_key=None):
update_endpoint=False, tags=None, kms_key=None, wait=True):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
Expand Down Expand Up @@ -256,6 +256,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
tags (List[dict[str, str]]): The list of tags to attach to this specific endpoint.
kms_key (str): The ARN of the KMS key that is used to encrypt the data on the
storage volume attached to the instance hosting the endpoint.
wait (bool): Whether the call should wait until the deployment of this model completes (default: True).

Returns:
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
Expand Down Expand Up @@ -296,7 +297,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant],
tags, kms_key)
tags, kms_key, wait)

if self.predictor_cls:
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def pipeline_container_def(self, instance_type):

return sagemaker.pipeline_container_def(self.models, instance_type)

def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None):
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.

Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
Expand All @@ -86,6 +86,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
endpoint_name (str): The name of the endpoint to create (default: None).
If not specified, a unique endpoint name will be created.
tags (List[dict[str, str]]): The list of tags to attach to this specific endpoint.
wait (bool): Whether the call should wait until the deployment of the model completes (default: True).

Returns:
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
Expand All @@ -101,7 +102,8 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags

production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count)
self.endpoint_name = endpoint_name or self.name
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags,
wait=wait)
if self.predictor_cls:
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,8 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim

return tuner

def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, **kwargs):
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, wait=True,
**kwargs):
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
``sagemaker.RealTimePredictor`` object.

Expand All @@ -342,6 +343,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
the name of the training job is used.
wait (bool): Whether the call should wait until the deployment of the model completes (default: True).
**kwargs: Other arguments needed for deployment. Please refer to the ``create_model()`` method of
the associated estimator to see what other arguments are needed.

Expand All @@ -354,7 +356,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
sagemaker_session=self.estimator.sagemaker_session)
return best_estimator.deploy(initial_instance_count, instance_type,
accelerator_type=accelerator_type,
endpoint_name=endpoint_name, **kwargs)
endpoint_name=endpoint_name, wait=wait, **kwargs)

def stop_tuning_job(self):
"""Stop latest running hyperparameter tuning job.
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,8 @@ def test_fit_deploy_keep_tags(sagemaker_session):
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
variant,
tags,
None)
None,
True)

sagemaker_session.create_model.assert_called_with(
ANY,
Expand Down
15 changes: 10 additions & 5 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def test_deploy(sagemaker_session, tmpdir):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -182,7 +183,8 @@ def test_deploy_endpoint_name(sagemaker_session, tmpdir):
'InitialInstanceCount': 55,
'VariantName': 'AllTraffic'}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -199,7 +201,8 @@ def test_deploy_tags(sagemaker_session, tmpdir):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
tags,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -217,7 +220,8 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
'VariantName': 'AllTraffic',
'AcceleratorType': ACCELERATOR_TYPE}],
None,
None)
None,
True)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
Expand All @@ -235,7 +239,8 @@ def test_deploy_kms_key(tfo, time, sagemaker_session):
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None,
key)
key,
True)


@patch('sagemaker.session.Session')
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def test_deploy(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None)
None,
wait=True)


@patch('tarfile.open')
Expand All @@ -119,7 +120,8 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
None)
None,
wait=True)


@patch('tarfile.open')
Expand Down Expand Up @@ -178,7 +180,8 @@ def test_deploy_tags(tfo, time, sagemaker_session):
'InstanceType': INSTANCE_TYPE,
'InitialInstanceCount': 1,
'VariantName': 'AllTraffic'}],
tags)
tags,
wait=True)


def test_delete_model_without_deploy(sagemaker_session):
Expand Down