Skip to content

Commit 5740939

Browse files
chuyang-dengChoiByungWook
authored andcommitted
Add support to delete model within Predictor and Pipeline class. (#647)
1 parent a934a1a commit 5740939

30 files changed

+375
-9
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ CHANGELOG
66
==========
77

88
* doc-fix: Remove incorrect parameter for EI TFS Python README
9+
* feature: ``Predictor``: delete SageMaker model
10+
* feature: ``Pipeline``: delete SageMaker model
911

1012
1.18.3.post1
1113
============

README.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ Here is an end to end example of how to use a SageMaker Estimator:
192192
# Tears down the SageMaker endpoint and endpoint configuration
193193
mxnet_predictor.delete_endpoint()
194194
195+
# Deletes the SageMaker model
196+
mxnet_predictor.delete_model()
195197
196198
The example above will eventually delete both the SageMaker endpoint and endpoint configuration through `delete_endpoint()`. If you want to keep your SageMaker endpoint configuration, use the value False for the `delete_endpoint_config` parameter, as shown below.
197199

@@ -230,6 +232,9 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
230232
# Tears down the SageMaker endpoint and endpoint configuration
231233
mxnet_predictor.delete_endpoint()
232234
235+
# Deletes the SageMaker model
236+
mxnet_predictor.delete_model()
237+
233238
Training Metrics
234239
~~~~~~~~~~~~~~~~
235240
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
@@ -284,6 +289,9 @@ We can take the example in `Using Estimators <#using-estimators>`__ , and use e
284289
# Tears down the endpoint container and deletes the corresponding endpoint configuration
285290
mxnet_predictor.delete_endpoint()
286291
292+
# Deletes the model
293+
mxnet_predictor.delete_model()
294+
287295
288296
If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor.
289297
The correct session is generated when you call ``model.deploy()``.
@@ -307,6 +315,9 @@ Here is an end-to-end example:
307315
# Tear down the endpoint container and delete the corresponding endpoint configuration
308316
predictor.delete_endpoint()
309317
318+
# Deletes the model
319+
predictor.delete_model()
320+
310321
311322
If you don't want to deploy your model locally, you can also choose to perform a Local Batch Transform Job. This is
312323
useful if you want to test your container before creating a SageMaker Batch Transform Job. Note that the performance

src/sagemaker/pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,14 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
103103
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
104104
if self.predictor_cls:
105105
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
106+
107+
def delete_model(self):
    """Remove the SageMaker model that backs this pipeline model.

    Only the pipeline's own SageMaker model is deleted; the list of SageMaker models used in the individual
    containers of the inference pipeline is left untouched.

    Raises:
        ValueError: If ``self.name`` is None.

    """
    model_name = self.name
    if model_name is None:
        raise ValueError('The SageMaker model must be created before attempting to delete.')

    self.sagemaker_session.delete_model(model_name)

src/sagemaker/predictor.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
5656
self.deserializer = deserializer
5757
self.content_type = content_type or getattr(serializer, 'content_type', None)
5858
self.accept = accept or getattr(deserializer, 'accept', None)
59+
self._endpoint_config_name = self._get_endpoint_config_name()
60+
self._model_names = self._get_model_names()
5961

6062
def predict(self, data, initial_args=None):
6163
"""Return the inference from the specified endpoint.
@@ -109,23 +111,51 @@ def _delete_endpoint_config(self):
109111
"""Delete the Amazon SageMaker endpoint configuration
110112
111113
"""
112-
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
113-
endpoint_config_name = endpoint_description['EndpointConfigName']
114-
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)
114+
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
115115

116116
def delete_endpoint(self, delete_endpoint_config=True):
117-
"""Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor.
117+
"""Delete the Amazon SageMaker endpoint backing this predictor. Also delete the endpoint configuration attached
118+
to it if delete_endpoint_config is True.
118119
119120
Args:
120-
delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint
121-
configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True)
121+
delete_endpoint_config (bool, optional): Flag to indicate whether to delete endpoint configuration together
122+
with endpoint. Defaults to True. If True, both endpoint and endpoint configuration will be deleted. If
123+
False, only endpoint will be deleted.
122124
123125
"""
124126
if delete_endpoint_config:
125127
self._delete_endpoint_config()
126128

127129
self.sagemaker_session.delete_endpoint(self.endpoint)
128130

131+
def delete_model(self):
    """Delete the Amazon SageMaker models backing this predictor.

    Deletion is attempted for every name in ``self._model_names``; a failure on one model does not stop the
    remaining models from being attempted.

    Raises:
        Exception: If one or more models could not be deleted. The message lists the names of the models
            whose deletion failed.

    """
    failed_models = []
    for model_name in self._model_names:
        try:
            self.sagemaker_session.delete_model(model_name)
        except Exception:  # pylint: disable=broad-except
            # Keep going so one failure does not block the other deletions; failures are reported together.
            failed_models.append(model_name)

    # A non-empty list is itself the failure signal, so no separate boolean flag is needed.
    if failed_models:
        raise Exception('One or more models cannot be deleted, please retry. \n'
                        'Failed models: {}'.format(', '.join(failed_models)))
147+
148+
def _get_endpoint_config_name(self):
    """Return the name of the endpoint configuration attached to ``self.endpoint``.

    The name is read from the ``EndpointConfigName`` field of the ``describe_endpoint`` response.

    """
    client = self.sagemaker_session.sagemaker_client
    return client.describe_endpoint(EndpointName=self.endpoint)['EndpointConfigName']
152+
153+
def _get_model_names(self):
    """Return the names of the models referenced by this predictor's endpoint configuration.

    Returns:
        list[str]: The ``ModelName`` of each entry in the ``ProductionVariants`` of the
            ``describe_endpoint_config`` response, in the order returned.

    """
    endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config(
        EndpointConfigName=self._endpoint_config_name)
    # Build a real list rather than returning ``map(...)``: on Python 3 ``map`` is a one-shot iterator, so a
    # stored result would be exhausted after the first pass and later iterations would see no models.
    return [variant['ModelName'] for variant in endpoint_config['ProductionVariants']]
158+
129159

130160
class _CsvSerializer(object):
131161
def __init__(self):

tests/integ/test_inference_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,8 @@ def test_inference_pipeline_model_deploy(sagemaker_session):
9292

9393
invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
9494
assert (predictor.predict(invalid_data) is None)
95+
96+
model.delete_model()
97+
with pytest.raises(Exception) as exception:
98+
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
99+
assert 'Could not find model' in str(exception.value)

tests/integ/test_kmeans.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ def test_kmeans(sagemaker_session):
7575
assert record.label["closest_cluster"] is not None
7676
assert record.label["distance_to_cluster"] is not None
7777

78+
predictor.delete_model()
79+
with pytest.raises(Exception) as exception:
80+
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
81+
assert 'Could not find model' in str(exception.value)
82+
7883

7984
def test_async_kmeans(sagemaker_session):
8085
training_job_name = ""

tests/integ/test_mxnet_train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
7272
data = numpy.zeros(shape=(1, 1, 28, 28))
7373
predictor.predict(data)
7474

75+
predictor.delete_model()
76+
with pytest.raises(Exception) as exception:
77+
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
78+
assert 'Could not find model' in str(exception.value)
79+
7580

7681
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
7782
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

tests/unit/test_chainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
GPU = 'ml.p2.xlarge'
4646
CPU = 'ml.c4.xlarge'
4747

48+
ENDPOINT_DESC = {
49+
'EndpointConfigName': 'test-endpoint'
50+
}
51+
52+
ENDPOINT_CONFIG_DESC = {
53+
'ProductionVariants': [{'ModelName': 'model-1'},
54+
{'ModelName': 'model-2'}]
55+
}
56+
4857

4958
@pytest.fixture()
5059
def sagemaker_session():
@@ -54,6 +63,8 @@ def sagemaker_session():
5463

5564
describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
5665
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
66+
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
67+
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5768
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5869
session.expand_role = Mock(name="expand_role", return_value=ROLE)
5970
return session

tests/unit/test_estimator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@
102102
'ModelDataUrl': MODEL_DATA,
103103
}
104104

105+
ENDPOINT_DESC = {
106+
'EndpointConfigName': 'test-endpoint'
107+
}
108+
109+
ENDPOINT_CONFIG_DESC = {
110+
'ProductionVariants': [{'ModelName': 'model-1'},
111+
{'ModelName': 'model-2'}]
112+
}
113+
105114

106115
class DummyFramework(Framework):
107116
__framework_name__ = 'dummy'
@@ -146,6 +155,8 @@ def sagemaker_session():
146155
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
147156
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
148157
return_value=DESCRIBE_TRAINING_JOB_RESULT)
158+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
159+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
149160
return sms
150161

151162

tests/unit/test_fm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
}
3838
}
3939

40+
ENDPOINT_DESC = {
41+
'EndpointConfigName': 'test-endpoint'
42+
}
43+
44+
ENDPOINT_CONFIG_DESC = {
45+
'ProductionVariants': [{'ModelName': 'model-1'},
46+
{'ModelName': 'model-2'}]
47+
}
48+
4049

4150
@pytest.fixture()
4251
def sagemaker_session():
@@ -47,6 +56,8 @@ def sagemaker_session():
4756
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4857
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4958
return_value=DESCRIBE_TRAINING_JOB_RESULT)
59+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
60+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5061
return sms
5162

5263

tests/unit/test_ipinsights.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@
3939
}
4040
}
4141

42+
ENDPOINT_DESC = {
43+
'EndpointConfigName': 'test-endpoint'
44+
}
45+
46+
ENDPOINT_CONFIG_DESC = {
47+
'ProductionVariants': [{'ModelName': 'model-1'},
48+
{'ModelName': 'model-2'}]
49+
}
50+
4251

4352
@pytest.fixture()
4453
def sagemaker_session():
@@ -49,6 +58,8 @@ def sagemaker_session():
4958
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5059
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
5160
return_value=DESCRIBE_TRAINING_JOB_RESULT)
61+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
62+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5263

5364
return sms
5465

tests/unit/test_kmeans.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
}
3737
}
3838

39+
ENDPOINT_DESC = {
40+
'EndpointConfigName': 'test-endpoint'
41+
}
42+
43+
ENDPOINT_CONFIG_DESC = {
44+
'ProductionVariants': [{'ModelName': 'model-1'},
45+
{'ModelName': 'model-2'}]
46+
}
47+
3948

4049
@pytest.fixture()
4150
def sagemaker_session():
@@ -46,6 +55,8 @@ def sagemaker_session():
4655
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4756
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4857
return_value=DESCRIBE_TRAINING_JOB_RESULT)
58+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
59+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4960

5061
return sms
5162

tests/unit/test_knn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
}
4141
}
4242

43+
ENDPOINT_DESC = {
44+
'EndpointConfigName': 'test-endpoint'
45+
}
46+
47+
ENDPOINT_CONFIG_DESC = {
48+
'ProductionVariants': [{'ModelName': 'model-1'},
49+
{'ModelName': 'model-2'}]
50+
}
51+
4352

4453
@pytest.fixture()
4554
def sagemaker_session():
@@ -50,6 +59,8 @@ def sagemaker_session():
5059
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5160
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
5261
return_value=DESCRIBE_TRAINING_JOB_RESULT)
62+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
63+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5364

5465
return sms
5566

tests/unit/test_lda.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@
3535
}
3636
}
3737

38+
ENDPOINT_DESC = {
39+
'EndpointConfigName': 'test-endpoint'
40+
}
41+
42+
ENDPOINT_CONFIG_DESC = {
43+
'ProductionVariants': [{'ModelName': 'model-1'},
44+
{'ModelName': 'model-2'}]
45+
}
46+
3847

3948
@pytest.fixture()
4049
def sagemaker_session():
@@ -44,6 +53,8 @@ def sagemaker_session():
4453
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4554
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4655
return_value=DESCRIBE_TRAINING_JOB_RESULT)
56+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
57+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4758

4859
return sms
4960

tests/unit/test_linear_learner.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
}
3838
}
3939

40+
ENDPOINT_DESC = {
41+
'EndpointConfigName': 'test-endpoint'
42+
}
43+
44+
ENDPOINT_CONFIG_DESC = {
45+
'ProductionVariants': [{'ModelName': 'model-1'},
46+
{'ModelName': 'model-2'}]
47+
}
48+
4049

4150
@pytest.fixture()
4251
def sagemaker_session():
@@ -47,6 +56,8 @@ def sagemaker_session():
4756
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4857
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4958
return_value=DESCRIBE_TRAINING_JOB_RESULT)
59+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
60+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5061

5162
return sms
5263

tests/unit/test_mxnet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
CPU_C5 = 'ml.c5.xlarge'
4646
LAUNCH_PS_DISTRIBUTIONS_DICT = {'parameter_server': {'enabled': True}}
4747

48+
ENDPOINT_DESC = {
49+
'EndpointConfigName': 'test-endpoint'
50+
}
51+
52+
ENDPOINT_CONFIG_DESC = {
53+
'ProductionVariants': [{'ModelName': 'model-1'},
54+
{'ModelName': 'model-2'}]
55+
}
56+
4857

4958
@pytest.fixture()
5059
def sagemaker_session():
@@ -55,6 +64,8 @@ def sagemaker_session():
5564
describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
5665
describe_compilation = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/model_c5.tar.gz'}}
5766
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
67+
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
68+
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5869
session.wait_for_compilation_job = Mock(return_value=describe_compilation)
5970
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6071
session.expand_role = Mock(name="expand_role", return_value=ROLE)

0 commit comments

Comments
 (0)