Skip to content

Commit d229105

Browse files
laurenyunadiaya
andauthored
fix: honor 'wait' flag when updating endpoint (#1222)
Co-authored-by: Nadia Yakimakha <[email protected]>
1 parent bd61feb commit d229105

File tree

8 files changed

+73
-12
lines changed

8 files changed

+73
-12
lines changed

src/sagemaker/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,9 @@ def deploy(
473473
kms_key=kms_key,
474474
data_capture_config_dict=data_capture_config_dict,
475475
)
476-
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
476+
self.sagemaker_session.update_endpoint(
477+
self.endpoint_name, endpoint_config_name, wait=wait
478+
)
477479
else:
478480
self.sagemaker_session.endpoint_from_production_variants(
479481
name=self.endpoint_name,

src/sagemaker/multidatamodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def deploy(
243243
kms_key=kms_key,
244244
data_capture_config_dict=data_capture_config_dict,
245245
)
246-
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
246+
self.sagemaker_session.update_endpoint(
247+
self.endpoint_name, endpoint_config_name, wait=wait
248+
)
247249
else:
248250
self.sagemaker_session.endpoint_from_production_variants(
249251
name=self.endpoint_name,

src/sagemaker/pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def deploy(
159159
tags=tags,
160160
data_capture_config_dict=data_capture_config_dict,
161161
)
162-
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
162+
self.sagemaker_session.update_endpoint(
163+
self.endpoint_name, endpoint_config_name, wait=wait
164+
)
163165
else:
164166
self.sagemaker_session.endpoint_from_production_variants(
165167
name=self.endpoint_name,

src/sagemaker/session.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,8 +2363,8 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
23632363
self.wait_for_endpoint(endpoint_name)
23642364
return endpoint_name
23652365

2366-
def update_endpoint(self, endpoint_name, endpoint_config_name):
2367-
""" Update an Amazon SageMaker ``Endpoint`` according to the endpoint configuration
2366+
def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
2367+
"""Update an Amazon SageMaker ``Endpoint`` according to the endpoint configuration
23682368
specified in the request
23692369
23702370
Raise an error if endpoint with endpoint_name does not exist.
@@ -2373,10 +2373,14 @@ def update_endpoint(self, endpoint_name, endpoint_config_name):
23732373
endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
23742374
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to
23752375
deploy.
2376+
wait (bool): Whether to wait for the endpoint deployment to complete before returning
2377+
(default: True).
23762378
23772379
Returns:
23782380
str: Name of the Amazon SageMaker ``Endpoint`` being updated.
23792381
2382+
Raises:
2383+
ValueError: if the endpoint does not already exist
23802384
"""
23812385
if not _deployment_entity_exists(
23822386
lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
@@ -2389,6 +2393,9 @@ def update_endpoint(self, endpoint_name, endpoint_config_name):
23892393
self.sagemaker_client.update_endpoint(
23902394
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
23912395
)
2396+
2397+
if wait:
2398+
self.wait_for_endpoint(endpoint_name)
23922399
return endpoint_name
23932400

23942401
def delete_endpoint(self, endpoint_name):

tests/unit/test_model.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,32 +346,66 @@ def test_deploy_creates_correct_session(local_session, session, tmpdir):
346346
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
347347
def test_deploy_update_endpoint(sagemaker_session, tmpdir):
348348
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
349+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, update_endpoint=True)
350+
sagemaker_session.create_endpoint_config.assert_called_with(
351+
name=model.name,
352+
model_name=model.name,
353+
initial_instance_count=INSTANCE_COUNT,
354+
instance_type=INSTANCE_TYPE,
355+
accelerator_type=None,
356+
tags=None,
357+
kms_key=None,
358+
data_capture_config_dict=None,
359+
)
360+
config_name = sagemaker_session.create_endpoint_config(
361+
name=model.name,
362+
model_name=model.name,
363+
initial_instance_count=INSTANCE_COUNT,
364+
instance_type=INSTANCE_TYPE,
365+
accelerator_type=ACCELERATOR_TYPE,
366+
)
367+
sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True)
368+
sagemaker_session.create_endpoint.assert_not_called()
369+
370+
371+
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
372+
def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir):
349373
endpoint_name = "endpoint-name"
374+
tags = [{"Key": "Value"}]
375+
kms_key = "foo"
376+
data_capture_config = MagicMock()
377+
378+
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
350379
model.deploy(
351380
instance_type=INSTANCE_TYPE,
352381
initial_instance_count=1,
353-
endpoint_name=endpoint_name,
354382
update_endpoint=True,
383+
endpoint_name=endpoint_name,
355384
accelerator_type=ACCELERATOR_TYPE,
385+
tags=tags,
386+
kms_key=kms_key,
387+
wait=False,
388+
data_capture_config=data_capture_config,
356389
)
357390
sagemaker_session.create_endpoint_config.assert_called_with(
358391
name=model.name,
359392
model_name=model.name,
360393
initial_instance_count=INSTANCE_COUNT,
361394
instance_type=INSTANCE_TYPE,
362395
accelerator_type=ACCELERATOR_TYPE,
363-
tags=None,
364-
kms_key=None,
365-
data_capture_config_dict=None,
396+
tags=tags,
397+
kms_key=kms_key,
398+
data_capture_config_dict=data_capture_config._to_request_dict(),
366399
)
367400
config_name = sagemaker_session.create_endpoint_config(
368401
name=model.name,
369402
model_name=model.name,
370403
initial_instance_count=INSTANCE_COUNT,
371404
instance_type=INSTANCE_TYPE,
372405
accelerator_type=ACCELERATOR_TYPE,
406+
wait=False,
373407
)
374-
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name)
408+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False)
375409
sagemaker_session.create_endpoint.assert_not_called()
376410

377411

tests/unit/test_multidatamodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,9 @@ def test_deploy_model_update(sagemaker_session):
295295
instance_type=INSTANCE_TYPE,
296296
accelerator_type=None,
297297
)
298-
sagemaker_session.update_endpoint.assert_called_with(MULTI_MODEL_ENDPOINT_NAME, config_name)
298+
sagemaker_session.update_endpoint.assert_called_with(
299+
MULTI_MODEL_ENDPOINT_NAME, config_name, wait=True
300+
)
299301
sagemaker_session.create_endpoint.assert_not_called()
300302

301303

tests/unit/test_pipeline_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_deploy_update_endpoint(tfo, time, sagemaker_session):
193193
initial_instance_count=INSTANCE_COUNT,
194194
instance_type=INSTANCE_TYPE,
195195
)
196-
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name)
196+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=True)
197197
sagemaker_session.create_endpoint.assert_not_called()
198198

199199

tests/unit/test_session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,6 +1785,18 @@ def test_update_endpoint_succeed(sagemaker_session):
17851785
assert returned_endpoint_name == endpoint_name
17861786

17871787

1788+
def test_update_endpoint_no_wait(sagemaker_session):
1789+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(
1790+
return_value={"EndpointStatus": "Updating"}
1791+
)
1792+
endpoint_name = "some-endpoint"
1793+
endpoint_config = "some-endpoint-config"
1794+
returned_endpoint_name = sagemaker_session.update_endpoint(
1795+
endpoint_name, endpoint_config, wait=False
1796+
)
1797+
assert returned_endpoint_name == endpoint_name
1798+
1799+
17881800
def test_update_endpoint_non_existing_endpoint(sagemaker_session):
17891801
error = ClientError(
17901802
{"Error": {"Code": "ValidationException", "Message": "Could not find entity"}}, "foo"

0 commit comments

Comments
 (0)