Skip to content

Commit dacacb6

Browse files
pinarawsahsan-z-khanDan
authored
change: Do lazy initialization in predictor (#2206)
Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Dan <[email protected]>
1 parent 15c5c6a commit dacacb6

File tree

3 files changed

+46
-18
lines changed

3 files changed

+46
-18
lines changed

src/sagemaker/predictor.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,17 @@ def __init__(
9191
self.sagemaker_session = sagemaker_session or Session()
9292
self.serializer = serializer
9393
self.deserializer = deserializer
94-
self._endpoint_config_name = self._get_endpoint_config_name()
95-
self._model_names = self._get_model_names()
94+
self._endpoint_config_name = None
95+
self._model_names = None
9696
self._context = None
9797

9898
def predict(
99-
self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
99+
self,
100+
data,
101+
initial_args=None,
102+
target_model=None,
103+
target_variant=None,
104+
inference_id=None,
100105
):
101106
"""Return the inference from the specified endpoint.
102107
@@ -138,7 +143,12 @@ def _handle_response(self, response):
138143
return self.deserializer.deserialize(response_body, content_type)
139144

140145
def _create_request_args(
141-
self, data, initial_args=None, target_model=None, target_variant=None, inference_id=None
146+
self,
147+
data,
148+
initial_args=None,
149+
target_model=None,
150+
target_variant=None,
151+
inference_id=None,
142152
):
143153
"""Placeholder docstring"""
144154
args = dict(initial_args) if initial_args else {}
@@ -223,24 +233,30 @@ def update_endpoint(
223233
associated with the endpoint.
224234
"""
225235
production_variants = None
236+
current_model_names = self._get_model_names()
226237

227238
if initial_instance_count or instance_type or accelerator_type or model_name:
228239
if instance_type is None or initial_instance_count is None:
229240
raise ValueError(
230241
"Missing initial_instance_count and/or instance_type. Provided values: "
231242
"initial_instance_count={}, instance_type={}, accelerator_type={}, "
232243
"model_name={}.".format(
233-
initial_instance_count, instance_type, accelerator_type, model_name
244+
initial_instance_count,
245+
instance_type,
246+
accelerator_type,
247+
model_name,
234248
)
235249
)
236250

237251
if model_name is None:
238-
if len(self._model_names) > 1:
252+
if len(current_model_names) > 1:
239253
raise ValueError(
240254
"Unable to choose a default model for a new EndpointConfig because "
241-
"the endpoint has multiple models: {}".format(", ".join(self._model_names))
255+
"the endpoint has multiple models: {}".format(
256+
", ".join(current_model_names)
257+
)
242258
)
243-
model_name = self._model_names[0]
259+
model_name = current_model_names[0]
244260
else:
245261
self._model_names = [model_name]
246262

@@ -252,9 +268,10 @@ def update_endpoint(
252268
)
253269
production_variants = [production_variant_config]
254270

255-
new_endpoint_config_name = name_from_base(self._endpoint_config_name)
271+
current_endpoint_config_name = self._get_endpoint_config_name()
272+
new_endpoint_config_name = name_from_base(current_endpoint_config_name)
256273
self.sagemaker_session.create_endpoint_config_from_existing(
257-
self._endpoint_config_name,
274+
current_endpoint_config_name,
258275
new_endpoint_config_name,
259276
new_tags=tags,
260277
new_kms_key=kms_key,
@@ -268,7 +285,8 @@ def update_endpoint(
268285

269286
def _delete_endpoint_config(self):
270287
"""Delete the Amazon SageMaker endpoint configuration"""
271-
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
288+
current_endpoint_config_name = self._get_endpoint_config_name()
289+
self.sagemaker_session.delete_endpoint_config(current_endpoint_config_name)
272290

273291
def delete_endpoint(self, delete_endpoint_config=True):
274292
"""Delete the Amazon SageMaker endpoint backing this predictor.
@@ -291,7 +309,8 @@ def delete_model(self):
291309
"""Deletes the Amazon SageMaker models backing this predictor."""
292310
request_failed = False
293311
failed_models = []
294-
for model_name in self._model_names:
312+
current_model_names = self._get_model_names()
313+
for model_name in current_model_names:
295314
try:
296315
self.sagemaker_session.delete_model(model_name)
297316
except Exception: # pylint: disable=broad-except
@@ -460,26 +479,33 @@ def endpoint_context(self):
460479
if len(contexts) != 0:
461480
# create endpoint context object
462481
self._context = EndpointContext.load(
463-
sagemaker_session=self.sagemaker_session, context_name=contexts[0].context_name
482+
sagemaker_session=self.sagemaker_session,
483+
context_name=contexts[0].context_name,
464484
)
465485

466486
return self._context
467487

468488
def _get_endpoint_config_name(self):
469489
"""Placeholder docstring"""
490+
if self._endpoint_config_name is not None:
491+
return self._endpoint_config_name
470492
endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(
471493
EndpointName=self.endpoint_name
472494
)
473-
endpoint_config_name = endpoint_desc["EndpointConfigName"]
474-
return endpoint_config_name
495+
self._endpoint_config_name = endpoint_desc["EndpointConfigName"]
496+
return self._endpoint_config_name
475497

476498
def _get_model_names(self):
477499
"""Placeholder docstring"""
500+
if self._model_names is not None:
501+
return self._model_names
502+
current_endpoint_config_name = self._get_endpoint_config_name()
478503
endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config(
479-
EndpointConfigName=self._endpoint_config_name
504+
EndpointConfigName=current_endpoint_config_name
480505
)
481506
production_variants = endpoint_config["ProductionVariants"]
482-
return [d["ModelName"] for d in production_variants]
507+
self._model_names = [d["ModelName"] for d in production_variants]
508+
return self._model_names
483509

484510
@property
485511
def content_type(self):

tests/integ/test_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _deploy_estimator_and_assert_instance_type(estimator, instance_type):
9797
try:
9898
predictor = estimator.deploy(1, instance_type)
9999

100-
model_name = predictor._model_names[0]
100+
model_name = predictor._get_model_names()[0]
101101
config_name = sagemaker_session.sagemaker_client.describe_endpoint(
102102
EndpointName=predictor.endpoint_name
103103
)["EndpointConfigName"]

tests/unit/test_predictor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def test_predict_call_pass_through():
6262
result = predictor.predict(data)
6363

6464
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
65+
assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
66+
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called
6567

6668
expected_request_args = {
6769
"Accept": DEFAULT_ACCEPT,

0 commit comments

Comments
 (0)