@@ -96,10 +96,6 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
96
96
assert "Could not find model" in str (exception .value )
97
97
98
98
99
- @pytest .mark .skip (
100
- reason = "This test has always failed, but the failure was masked by a bug. "
101
- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
102
- )
103
99
def test_deploy_model_with_tags_and_kms (mxnet_training_job , sagemaker_session , mxnet_full_version ):
104
100
endpoint_name = "test-mxnet-deploy-model-{}" .format (sagemaker_timestamp ())
105
101
@@ -123,18 +119,20 @@ def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, m
123
119
124
120
model .deploy (1 , "ml.m4.xlarge" , endpoint_name = endpoint_name , tags = tags , kms_key = kms_key_arn )
125
121
126
- returned_model = sagemaker_session .describe_model (EndpointName = model .name )
127
- returned_model_tags = sagemaker_session .list_tags (ResourceArn = returned_model [ "ModelArn" ])[
128
- "Tags"
129
- ]
122
+ returned_model = sagemaker_session .sagemaker_client . describe_model (ModelName = model .name )
123
+ returned_model_tags = sagemaker_session .sagemaker_client . list_tags (
124
+ ResourceArn = returned_model [ "ModelArn" ]
125
+ )[ "Tags" ]
130
126
131
- endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
132
- endpoint_tags = sagemaker_session .list_tags (ResourceArn = endpoint ["EndpointArn" ])["Tags" ]
127
+ endpoint = sagemaker_session .sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
128
+ endpoint_tags = sagemaker_session .sagemaker_client .list_tags (
129
+ ResourceArn = endpoint ["EndpointArn" ]
130
+ )["Tags" ]
133
131
134
- endpoint_config = sagemaker_session .describe_endpoint_config (
132
+ endpoint_config = sagemaker_session .sagemaker_client . describe_endpoint_config (
135
133
EndpointConfigName = endpoint ["EndpointConfigName" ]
136
134
)
137
- endpoint_config_tags = sagemaker_session .list_tags (
135
+ endpoint_config_tags = sagemaker_session .sagemaker_client . list_tags (
138
136
ResourceArn = endpoint_config ["EndpointConfigArn" ]
139
137
)["Tags" ]
140
138
@@ -148,10 +146,6 @@ def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, m
148
146
assert endpoint_config ["KmsKeyId" ] == kms_key_arn
149
147
150
148
151
- @pytest .mark .skip (
152
- reason = "This test has always failed, but the failure was masked by a bug. "
153
- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
154
- )
155
149
def test_deploy_model_with_update_endpoint (
156
150
mxnet_training_job , sagemaker_session , mxnet_full_version
157
151
):
@@ -172,26 +166,37 @@ def test_deploy_model_with_update_endpoint(
172
166
framework_version = mxnet_full_version ,
173
167
)
174
168
model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
175
- old_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
169
+ old_endpoint = sagemaker_session .sagemaker_client .describe_endpoint (
170
+ EndpointName = endpoint_name
171
+ )
176
172
old_config_name = old_endpoint ["EndpointConfigName" ]
177
173
178
174
model .deploy (1 , "ml.m4.xlarge" , update_endpoint = True , endpoint_name = endpoint_name )
179
- new_endpoint = sagemaker_session .describe_endpoint (EndpointName = endpoint_name )[
180
- "ProductionVariants"
181
- ]
182
- new_production_variants = new_endpoint ["ProductionVariants" ]
175
+
176
+ # Wait for endpoint to finish updating
177
+ max_retry_count = 40 # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
178
+ current_retry_count = 0
179
+ while current_retry_count <= max_retry_count :
180
+ if current_retry_count >= max_retry_count :
181
+ raise Exception ("Endpoint status not 'InService' within expected timeout." )
182
+ time .sleep (30 )
183
+ new_endpoint = sagemaker_session .sagemaker_client .describe_endpoint (
184
+ EndpointName = endpoint_name
185
+ )
186
+ current_retry_count += 1
187
+ if new_endpoint ["EndpointStatus" ] == "InService" :
188
+ break
189
+
183
190
new_config_name = new_endpoint ["EndpointConfigName" ]
191
+ new_config = sagemaker_session .sagemaker_client .describe_endpoint_config (
192
+ EndpointConfigName = new_config_name
193
+ )
184
194
185
195
assert old_config_name != new_config_name
186
- assert new_production_variants ["InstanceType" ] == "ml.m4.xlarge"
187
- assert new_production_variants ["InitialInstanceCount" ] == 1
188
- assert new_production_variants ["AcceleratorType" ] is None
196
+ assert new_config ["ProductionVariants" ][0 ]["InstanceType" ] == "ml.m4.xlarge"
197
+ assert new_config ["ProductionVariants" ][0 ]["InitialInstanceCount" ] == 1
189
198
190
199
191
- @pytest .mark .skip (
192
- reason = "This test has always failed, but the failure was masked by a bug. "
193
- "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968"
194
- )
195
200
def test_deploy_model_with_update_non_existing_endpoint (
196
201
mxnet_training_job , sagemaker_session , mxnet_full_version
197
202
):
@@ -216,7 +221,7 @@ def test_deploy_model_with_update_non_existing_endpoint(
216
221
framework_version = mxnet_full_version ,
217
222
)
218
223
model .deploy (1 , "ml.t2.medium" , endpoint_name = endpoint_name )
219
- sagemaker_session .describe_endpoint (EndpointName = endpoint_name )
224
+ sagemaker_session .sagemaker_client . describe_endpoint (EndpointName = endpoint_name )
220
225
221
226
with pytest .raises (ValueError , message = expected_error_message ):
222
227
model .deploy (
0 commit comments