-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add support to delete model within Predictor and Pipeline class. #647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
6e09223
f381475
55193a2
eae499c
c055924
d3c614f
edd00a6
c89b8f5
c552c04
c6f6d82
204b197
be7e86c
af76289
f6f46fd
e959e54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,6 +56,7 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ | |
self.deserializer = deserializer | ||
self.content_type = content_type or getattr(serializer, 'content_type', None) | ||
self.accept = accept or getattr(deserializer, 'accept', None) | ||
self._model_names = self._get_model_names() | ||
|
||
def predict(self, data, initial_args=None): | ||
"""Return the inference from the specified endpoint. | ||
|
@@ -109,23 +110,42 @@ def _delete_endpoint_config(self): | |
"""Delete the Amazon SageMaker endpoint configuration | ||
|
||
""" | ||
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint) | ||
endpoint_config_name = endpoint_description['EndpointConfigName'] | ||
self.sagemaker_session.delete_endpoint_config(endpoint_config_name) | ||
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name) | ||
ChoiByungWook marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def delete_endpoint(self, delete_endpoint_config=True): | ||
"""Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor. | ||
"""Delete the Amazon SageMaker endpoint backing this predictor. Also delete the endpoint configuration attached | ||
to it if delete_endpoint_config is True. | ||
|
||
Args: | ||
delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint | ||
configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True) | ||
delete_endpoint_config (bool, optional): Flag to indicate whether to delete endpoint configuration together | ||
with endpoint. Defaults to True. If True, both endpoint and endpoint configuration will be deleted. If | ||
False, only endpoint will be deleted. | ||
|
||
""" | ||
if delete_endpoint_config: | ||
self._delete_endpoint_config() | ||
|
||
self.sagemaker_session.delete_endpoint(self.endpoint) | ||
|
||
def delete_model(self): | ||
"""Deletes the Amazon SageMaker models backing this predictor. | ||
|
||
""" | ||
for model_name in self._model_names: | ||
self.sagemaker_session.delete_model(model_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the desired behavior if one or some of the requests fail? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! I think we should catch the exception and tell user the deletion is incomplete if one or model delete_model() fail. |
||
|
||
def _get_endpoint_config_desc(self): | ||
endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint) | ||
self._endpoint_config_name = endpoint_desc['EndpointConfigName'] | ||
endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config( | ||
EndpointConfigName=self._endpoint_config_name) | ||
return endpoint_config | ||
|
||
def _get_model_names(self): | ||
ChoiByungWook marked this conversation as resolved.
Show resolved
Hide resolved
|
||
endpoint_config = self._get_endpoint_config_desc() | ||
production_variants = endpoint_config['ProductionVariants'] | ||
return map(lambda d: d['ModelName'], production_variants) | ||
|
||
|
||
class _CsvSerializer(object): | ||
def __init__(self): | ||
|
Uh oh!
There was an error while loading. Please reload this page.