Skip to content

Commit 0dc63c1

Browse files
committed
fix integ test
1 parent 0df2a9a commit 0dc63c1

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

tests/integ/test_inference_pipeline.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
import os
17+
import time
1718

1819
import pytest
1920
from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
@@ -177,22 +178,37 @@ def test_inference_pipeline_model_deploy_with_update_endpoint(sagemaker_session)
177178
models=[sparkml_model, xgb_model],
178179
role="SageMakerRole",
179180
sagemaker_session=sagemaker_session,
180-
name=endpoint_name,
181181
)
182-
model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
183-
old_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
182+
model.deploy(1, "ml.t2.medium", endpoint_name=endpoint_name)
183+
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
184+
EndpointName=endpoint_name
185+
)
184186
old_config_name = old_endpoint["EndpointConfigName"]
185187

186188
model.deploy(1, "ml.m4.xlarge", update_endpoint=True, endpoint_name=endpoint_name)
187-
new_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)[
188-
"ProductionVariants"
189-
]
190-
new_production_variants = new_endpoint["ProductionVariants"]
189+
190+
# Wait for endpoint to finish updating
191+
max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
192+
current_retry_count = 0
193+
while current_retry_count <= max_retry_count:
194+
if current_retry_count >= max_retry_count:
195+
raise Exception("Endpoint status not 'InService' within expected timeout.")
196+
time.sleep(30)
197+
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
198+
EndpointName=endpoint_name
199+
)
200+
current_retry_count += 1
201+
if new_endpoint["EndpointStatus"] == "InService":
202+
break
203+
191204
new_config_name = new_endpoint["EndpointConfigName"]
205+
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
206+
EndpointConfigName=new_config_name
207+
)
192208

193209
assert old_config_name != new_config_name
194-
assert new_production_variants["InstanceType"] == "ml.m4.xlarge"
195-
assert new_production_variants["InitialInstanceCount"] == 1
210+
assert new_config["ProductionVariants"][0]["InstanceType"] == "ml.m4.xlarge"
211+
assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1
196212

197213
model.delete_model()
198214
with pytest.raises(Exception) as exception:

0 commit comments

Comments
 (0)