@@ -91,12 +91,17 @@ def __init__(
91
91
self .sagemaker_session = sagemaker_session or Session ()
92
92
self .serializer = serializer
93
93
self .deserializer = deserializer
94
- self ._endpoint_config_name = self . _get_endpoint_config_name ()
95
- self ._model_names = self . _get_model_names ()
94
+ self ._endpoint_config_name = None
95
+ self ._model_names = None
96
96
self ._context = None
97
97
98
98
def predict (
99
- self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
99
+ self ,
100
+ data ,
101
+ initial_args = None ,
102
+ target_model = None ,
103
+ target_variant = None ,
104
+ inference_id = None ,
100
105
):
101
106
"""Return the inference from the specified endpoint.
102
107
@@ -138,7 +143,12 @@ def _handle_response(self, response):
138
143
return self .deserializer .deserialize (response_body , content_type )
139
144
140
145
def _create_request_args (
141
- self , data , initial_args = None , target_model = None , target_variant = None , inference_id = None
146
+ self ,
147
+ data ,
148
+ initial_args = None ,
149
+ target_model = None ,
150
+ target_variant = None ,
151
+ inference_id = None ,
142
152
):
143
153
"""Placeholder docstring"""
144
154
args = dict (initial_args ) if initial_args else {}
@@ -223,24 +233,30 @@ def update_endpoint(
223
233
associated with the endpoint.
224
234
"""
225
235
production_variants = None
236
+ current_model_names = self ._get_model_names ()
226
237
227
238
if initial_instance_count or instance_type or accelerator_type or model_name :
228
239
if instance_type is None or initial_instance_count is None :
229
240
raise ValueError (
230
241
"Missing initial_instance_count and/or instance_type. Provided values: "
231
242
"initial_instance_count={}, instance_type={}, accelerator_type={}, "
232
243
"model_name={}." .format (
233
- initial_instance_count , instance_type , accelerator_type , model_name
244
+ initial_instance_count ,
245
+ instance_type ,
246
+ accelerator_type ,
247
+ model_name ,
234
248
)
235
249
)
236
250
237
251
if model_name is None :
238
- if len (self . _model_names ) > 1 :
252
+ if len (current_model_names ) > 1 :
239
253
raise ValueError (
240
254
"Unable to choose a default model for a new EndpointConfig because "
241
- "the endpoint has multiple models: {}" .format (", " .join (self ._model_names ))
255
+ "the endpoint has multiple models: {}" .format (
256
+ ", " .join (current_model_names )
257
+ )
242
258
)
243
- model_name = self . _model_names [0 ]
259
+ model_name = current_model_names [0 ]
244
260
else :
245
261
self ._model_names = [model_name ]
246
262
@@ -252,9 +268,10 @@ def update_endpoint(
252
268
)
253
269
production_variants = [production_variant_config ]
254
270
255
- new_endpoint_config_name = name_from_base (self ._endpoint_config_name )
271
+ current_endpoint_config_name = self ._get_endpoint_config_name ()
272
+ new_endpoint_config_name = name_from_base (current_endpoint_config_name )
256
273
self .sagemaker_session .create_endpoint_config_from_existing (
257
- self . _endpoint_config_name ,
274
+ current_endpoint_config_name ,
258
275
new_endpoint_config_name ,
259
276
new_tags = tags ,
260
277
new_kms_key = kms_key ,
@@ -268,7 +285,8 @@ def update_endpoint(
268
285
269
286
def _delete_endpoint_config(self):
    """Delete the Amazon SageMaker endpoint configuration.

    The configuration name is obtained from ``_get_endpoint_config_name``
    and passed to the session's ``delete_endpoint_config``.
    """
    config_name = self._get_endpoint_config_name()
    session = self.sagemaker_session
    session.delete_endpoint_config(config_name)
272
290
273
291
def delete_endpoint (self , delete_endpoint_config = True ):
274
292
"""Delete the Amazon SageMaker endpoint backing this predictor.
@@ -291,7 +309,8 @@ def delete_model(self):
291
309
"""Deletes the Amazon SageMaker models backing this predictor."""
292
310
request_failed = False
293
311
failed_models = []
294
- for model_name in self ._model_names :
312
+ current_model_names = self ._get_model_names ()
313
+ for model_name in current_model_names :
295
314
try :
296
315
self .sagemaker_session .delete_model (model_name )
297
316
except Exception : # pylint: disable=broad-except
@@ -460,26 +479,33 @@ def endpoint_context(self):
460
479
if len (contexts ) != 0 :
461
480
# create endpoint context object
462
481
self ._context = EndpointContext .load (
463
- sagemaker_session = self .sagemaker_session , context_name = contexts [0 ].context_name
482
+ sagemaker_session = self .sagemaker_session ,
483
+ context_name = contexts [0 ].context_name ,
464
484
)
465
485
466
486
return self ._context
467
487
468
488
def _get_endpoint_config_name(self):
    """Return the name of the endpoint configuration used by this endpoint.

    On first use the name is read from the ``EndpointConfigName`` field of
    the ``describe_endpoint`` response for ``self.endpoint_name`` and stored
    on ``self._endpoint_config_name``. Later calls return the stored value
    without calling the service again.
    """
    cached_name = self._endpoint_config_name
    if cached_name is None:
        client = self.sagemaker_session.sagemaker_client
        description = client.describe_endpoint(EndpointName=self.endpoint_name)
        self._endpoint_config_name = description["EndpointConfigName"]
        cached_name = self._endpoint_config_name
    return cached_name
475
497
476
498
def _get_model_names(self):
    """Return the model names of the endpoint's production variants.

    On first use the endpoint configuration (named by
    ``_get_endpoint_config_name``) is fetched with
    ``describe_endpoint_config``. The ``ModelName`` of each entry in its
    ``ProductionVariants`` list is collected and stored on
    ``self._model_names``. Later calls return the stored list.
    """
    if self._model_names is None:
        config_name = self._get_endpoint_config_name()
        client = self.sagemaker_session.sagemaker_client
        config = client.describe_endpoint_config(EndpointConfigName=config_name)
        variants = config["ProductionVariants"]
        self._model_names = [variant["ModelName"] for variant in variants]
    return self._model_names
483
509
484
510
@property
485
511
def content_type (self ):
0 commit comments