Skip to content

Commit dd78d53

Browse files
committed
change: Do lazy initialization of model names and endpoint configuration name in predictor
1 parent 334f942 commit dd78d53

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

src/sagemaker/predictor.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def __init__(
9191
self.sagemaker_session = sagemaker_session or Session()
9292
self.serializer = serializer
9393
self.deserializer = deserializer
94-
self._endpoint_config_name = self._get_endpoint_config_name()
95-
self._model_names = self._get_model_names()
94+
self._endpoint_config_name = None
95+
self._model_names = None
9696
self._context = None
9797

9898
def predict(
@@ -223,6 +223,7 @@ def update_endpoint(
223223
associated with the endpoint.
224224
"""
225225
production_variants = None
226+
current_model_names = self._get_model_names()
226227

227228
if initial_instance_count or instance_type or accelerator_type or model_name:
228229
if instance_type is None or initial_instance_count is None:
@@ -235,12 +236,12 @@ def update_endpoint(
235236
)
236237

237238
if model_name is None:
238-
if len(self._model_names) > 1:
239+
if len(current_model_names) > 1:
239240
raise ValueError(
240241
"Unable to choose a default model for a new EndpointConfig because "
241-
"the endpoint has multiple models: {}".format(", ".join(self._model_names))
242+
"the endpoint has multiple models: {}".format(", ".join(current_model_names))
242243
)
243-
model_name = self._model_names[0]
244+
model_name = current_model_names[0]
244245
else:
245246
self._model_names = [model_name]
246247

@@ -252,9 +253,10 @@ def update_endpoint(
252253
)
253254
production_variants = [production_variant_config]
254255

255-
new_endpoint_config_name = name_from_base(self._endpoint_config_name)
256+
current_endpoint_config_name = self._get_endpoint_config_name()
257+
new_endpoint_config_name = name_from_base(current_endpoint_config_name)
256258
self.sagemaker_session.create_endpoint_config_from_existing(
257-
self._endpoint_config_name,
259+
current_endpoint_config_name,
258260
new_endpoint_config_name,
259261
new_tags=tags,
260262
new_kms_key=kms_key,
@@ -268,7 +270,8 @@ def update_endpoint(
268270

269271
def _delete_endpoint_config(self):
270272
"""Delete the Amazon SageMaker endpoint configuration"""
271-
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
273+
current_endpoint_config_name = self._get_endpoint_config_name()
274+
self.sagemaker_session.delete_endpoint_config(current_endpoint_config_name)
272275

273276
def delete_endpoint(self, delete_endpoint_config=True):
274277
"""Delete the Amazon SageMaker endpoint backing this predictor.
@@ -291,7 +294,8 @@ def delete_model(self):
291294
"""Deletes the Amazon SageMaker models backing this predictor."""
292295
request_failed = False
293296
failed_models = []
294-
for model_name in self._model_names:
297+
current_model_names = self._get_model_names()
298+
for model_name in current_model_names:
295299
try:
296300
self.sagemaker_session.delete_model(model_name)
297301
except Exception: # pylint: disable=broad-except
@@ -467,19 +471,25 @@ def endpoint_context(self):
467471

468472
def _get_endpoint_config_name(self):
469473
"""Placeholder docstring"""
474+
if self._endpoint_config_name is not None:
475+
return self._endpoint_config_name
470476
endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(
471477
EndpointName=self.endpoint_name
472478
)
473-
endpoint_config_name = endpoint_desc["EndpointConfigName"]
474-
return endpoint_config_name
479+
self._endpoint_config_name = endpoint_desc["EndpointConfigName"]
480+
return self._endpoint_config_name
475481

476482
def _get_model_names(self):
477483
"""Placeholder docstring"""
484+
if self._model_names is not None:
485+
return self._model_names
486+
current_endpoint_config_name = self._get_endpoint_config_name()
478487
endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config(
479-
EndpointConfigName=self._endpoint_config_name
488+
EndpointConfigName=current_endpoint_config_name
480489
)
481490
production_variants = endpoint_config["ProductionVariants"]
482-
return [d["ModelName"] for d in production_variants]
491+
self._model_names = [d["ModelName"] for d in production_variants]
492+
return self._model_names
483493

484494
@property
485495
def content_type(self):

0 commit comments

Comments
 (0)