@@ -91,8 +91,8 @@ def __init__(
91
91
self .sagemaker_session = sagemaker_session or Session ()
92
92
self .serializer = serializer
93
93
self .deserializer = deserializer
94
- self ._endpoint_config_name = self . _get_endpoint_config_name ()
95
- self ._model_names = self . _get_model_names ()
94
+ self ._endpoint_config_name = None
95
+ self ._model_names = None
96
96
self ._context = None
97
97
98
98
def predict (
@@ -223,6 +223,7 @@ def update_endpoint(
223
223
associated with the endpoint.
224
224
"""
225
225
production_variants = None
226
+ current_model_names = self ._get_model_names ()
226
227
227
228
if initial_instance_count or instance_type or accelerator_type or model_name :
228
229
if instance_type is None or initial_instance_count is None :
@@ -235,12 +236,12 @@ def update_endpoint(
235
236
)
236
237
237
238
if model_name is None :
238
- if len (self . _model_names ) > 1 :
239
+ if len (current_model_names ) > 1 :
239
240
raise ValueError (
240
241
"Unable to choose a default model for a new EndpointConfig because "
241
- "the endpoint has multiple models: {}" .format (", " .join (self . _model_names ))
242
+ "the endpoint has multiple models: {}" .format (", " .join (current_model_names ))
242
243
)
243
- model_name = self . _model_names [0 ]
244
+ model_name = current_model_names [0 ]
244
245
else :
245
246
self ._model_names = [model_name ]
246
247
@@ -252,9 +253,10 @@ def update_endpoint(
252
253
)
253
254
production_variants = [production_variant_config ]
254
255
255
- new_endpoint_config_name = name_from_base (self ._endpoint_config_name )
256
+ current_endpoint_config_name = self ._get_endpoint_config_name ()
257
+ new_endpoint_config_name = name_from_base (current_endpoint_config_name )
256
258
self .sagemaker_session .create_endpoint_config_from_existing (
257
- self . _endpoint_config_name ,
259
+ current_endpoint_config_name ,
258
260
new_endpoint_config_name ,
259
261
new_tags = tags ,
260
262
new_kms_key = kms_key ,
@@ -268,7 +270,8 @@ def update_endpoint(
268
270
269
271
def _delete_endpoint_config (self ):
270
272
"""Delete the Amazon SageMaker endpoint configuration"""
271
- self .sagemaker_session .delete_endpoint_config (self ._endpoint_config_name )
273
+ current_endpoint_config_name = self ._get_endpoint_config_name ()
274
+ self .sagemaker_session .delete_endpoint_config (current_endpoint_config_name )
272
275
273
276
def delete_endpoint (self , delete_endpoint_config = True ):
274
277
"""Delete the Amazon SageMaker endpoint backing this predictor.
@@ -291,7 +294,8 @@ def delete_model(self):
291
294
"""Deletes the Amazon SageMaker models backing this predictor."""
292
295
request_failed = False
293
296
failed_models = []
294
- for model_name in self ._model_names :
297
+ current_model_names = self ._get_model_names ()
298
+ for model_name in current_model_names :
295
299
try :
296
300
self .sagemaker_session .delete_model (model_name )
297
301
except Exception : # pylint: disable=broad-except
@@ -467,19 +471,25 @@ def endpoint_context(self):
467
471
468
472
def _get_endpoint_config_name (self ):
469
473
"""Placeholder docstring"""
474
+ if self ._endpoint_config_name is not None :
475
+ return self ._endpoint_config_name
470
476
endpoint_desc = self .sagemaker_session .sagemaker_client .describe_endpoint (
471
477
EndpointName = self .endpoint_name
472
478
)
473
- endpoint_config_name = endpoint_desc ["EndpointConfigName" ]
474
- return endpoint_config_name
479
+ self . _endpoint_config_name = endpoint_desc ["EndpointConfigName" ]
480
+ return self . _endpoint_config_name
475
481
476
482
def _get_model_names (self ):
477
483
"""Placeholder docstring"""
484
+ if self ._model_names is not None :
485
+ return self ._model_names
486
+ current_endpoint_config_name = self ._get_endpoint_config_name ()
478
487
endpoint_config = self .sagemaker_session .sagemaker_client .describe_endpoint_config (
479
- EndpointConfigName = self . _endpoint_config_name
488
+ EndpointConfigName = current_endpoint_config_name
480
489
)
481
490
production_variants = endpoint_config ["ProductionVariants" ]
482
- return [d ["ModelName" ] for d in production_variants ]
491
+ self ._model_names = [d ["ModelName" ] for d in production_variants ]
492
+ return self ._model_names
483
493
484
494
@property
485
495
def content_type (self ):
0 commit comments