Skip to content

Commit 0918810

Browse files
committed
Modify to check recommendations job res first.
1 parent 1179dc0 commit 0918810

File tree

2 files changed

+69
-55
lines changed

2 files changed

+69
-55
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 66 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -336,24 +336,18 @@ def _update_params_for_recommendation_id(
336336

337337
# Validate recommendation id
338338
if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
339-
raise ValueError("Inference Recommendation id is not valid")
339+
raise ValueError("inference_recommendation_id is not valid")
340340
job_or_model_name = inference_recommendation_id.split("/")[0]
341341

342342
sage_client = self.sagemaker_session.sagemaker_client
343-
344-
# Desribe inference recommendation job and model details
345-
recommendation_res, model_res = self._describe_recommendation_job_and_model(
346-
sage_client=sage_client,
347-
job_or_model_name=job_or_model_name,
348-
)
349-
350-
# Search the recommendation from above describe results
343+
# Get recommendation from right size job and model
351344
(
352345
right_size_recommendation,
353346
model_recommendation,
354-
) = self._get_right_size_and_model_recommendation(
355-
recommendation_res=recommendation_res,
356-
model_res=model_res,
347+
right_size_job_res,
348+
) = self._get_recommendation(
349+
sage_client=sage_client,
350+
job_or_model_name=job_or_model_name,
357351
inference_recommendation_id=inference_recommendation_id,
358352
)
359353

@@ -374,7 +368,7 @@ def _update_params_for_recommendation_id(
374368
"since they are in recommendation, or specify both of them if you want"
375369
"to override the recommendation."
376370
)
377-
input_config = recommendation_res["InputConfig"]
371+
input_config = right_size_job_res["InputConfig"]
378372
model_config = right_size_recommendation["ModelConfiguration"]
379373
envs = (
380374
model_config["EnvironmentParameters"]
@@ -498,56 +492,76 @@ def _convert_to_stopping_conditions_json(
498492
]
499493
return stopping_conditions
500494

501-
def _get_right_size_and_model_recommendation(
502-
self,
503-
model_res=None,
504-
recommendation_res=None,
505-
inference_recommendation_id=None,
506-
):
507-
"""Get recommendation from right size job or model"""
508-
right_size_recommendation, model_recommendation = None, None
509-
if recommendation_res:
510-
right_size_recommendation = self._get_recommendation(
511-
recommendation_list=recommendation_res["InferenceRecommendations"],
512-
inference_recommendation_id=inference_recommendation_id,
513-
)
514-
if model_res:
515-
model_recommendation = self._get_recommendation(
516-
recommendation_list=model_res["DeploymentRecommendation"][
517-
"RealTimeInferenceRecommendations"
518-
],
495+
def _get_recommendation(self, sage_client, job_or_model_name, inference_recommendation_id):
496+
"""Get recommendation from right size job and model"""
497+
right_size_recommendation, model_recommendation, right_size_job_res = None, None, None
498+
right_size_recommendation, right_size_job_res = self._get_right_size_recommendation(
499+
sage_client=sage_client,
500+
job_or_model_name=job_or_model_name,
501+
inference_recommendation_id=inference_recommendation_id,
502+
)
503+
if right_size_recommendation is None:
504+
model_recommendation = self._get_model_recommendation(
505+
sage_client=sage_client,
506+
job_or_model_name=job_or_model_name,
519507
inference_recommendation_id=inference_recommendation_id,
520508
)
521-
if right_size_recommendation is None and model_recommendation is None:
522-
raise ValueError("Inference Recommendation id is not valid")
509+
if model_recommendation is None:
510+
raise ValueError("inference_recommendation_id is not valid")
523511

524-
return right_size_recommendation, model_recommendation
512+
return right_size_recommendation, model_recommendation, right_size_job_res
525513

526-
def _get_recommendation(self, recommendation_list, inference_recommendation_id):
527-
"""Get recommendation based on recommendation id"""
528-
return next(
529-
(
530-
rec
531-
for rec in recommendation_list
532-
if rec["RecommendationId"] == inference_recommendation_id
533-
),
534-
None,
535-
)
536-
537-
def _describe_recommendation_job_and_model(self, sage_client, job_or_model_name):
538-
"""Describe inference recommendation job and model results"""
539-
recommendation_res, model_res = None, None
514+
def _get_right_size_recommendation(
515+
self,
516+
sage_client,
517+
job_or_model_name,
518+
inference_recommendation_id,
519+
):
520+
"""Get recommendation from right size job"""
521+
right_size_recommendation, right_size_job_res = None, None
540522
try:
541-
recommendation_res = sage_client.describe_inference_recommendations_job(
523+
right_size_job_res = sage_client.describe_inference_recommendations_job(
542524
JobName=job_or_model_name
543525
)
526+
if right_size_job_res:
527+
right_size_recommendation = self._search_recommendation(
528+
recommendation_list=right_size_job_res["InferenceRecommendations"],
529+
inference_recommendation_id=inference_recommendation_id,
530+
)
544531
except sage_client.exceptions.ResourceNotFound:
545532
pass
533+
534+
return right_size_recommendation, right_size_job_res
535+
536+
def _get_model_recommendation(
537+
self,
538+
sage_client,
539+
job_or_model_name,
540+
inference_recommendation_id,
541+
):
542+
"""Get recommendation from model"""
543+
model_recommendation = None
546544
try:
547545
model_res = sage_client.describe_model(ModelName=job_or_model_name)
546+
if model_res:
547+
model_recommendation = self._search_recommendation(
548+
recommendation_list=model_res["DeploymentRecommendation"][
549+
"RealTimeInferenceRecommendations"
550+
],
551+
inference_recommendation_id=inference_recommendation_id,
552+
)
548553
except sage_client.exceptions.ResourceNotFound:
549554
pass
550-
if recommendation_res is None and model_res is None:
551-
raise ValueError("Inference Recommendation id is not valid")
552555

553-
return recommendation_res, model_res
556+
return model_recommendation
557+
558+
def _search_recommendation(self, recommendation_list, inference_recommendation_id):
559+
"""Search recommendation based on recommendation id"""
560+
return next(
561+
(
562+
rec
563+
for rec in recommendation_list
564+
if rec["RecommendationId"] == inference_recommendation_id
565+
),
566+
None,
567+
)

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def test_deploy_ir_with_incompatible_parameters(sagemaker_session):
532532
def test_deploy_with_wrong_recommendation_id(sagemaker_session):
533533
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
534534

535-
with pytest.raises(ValueError, match="Inference Recommendation id is not valid"):
535+
with pytest.raises(ValueError, match="inference_recommendation_id is not valid"):
536536
model.deploy(
537537
inference_recommendation_id=INVALID_RECOMMENDATION_ID,
538538
)
@@ -640,7 +640,7 @@ def test_deploy_with_invalid_inference_recommendation_id(sagemaker_session):
640640

641641
with pytest.raises(
642642
ValueError,
643-
match="Inference Recommendation id is not valid",
643+
match="inference_recommendation_id is not valid",
644644
):
645645
model.deploy(
646646
inference_recommendation_id=NOT_EXISTED_RECOMMENDATION_ID,
@@ -655,7 +655,7 @@ def test_deploy_with_invalid_model_recommendation_id(sagemaker_session):
655655

656656
with pytest.raises(
657657
ValueError,
658-
match="Inference Recommendation id is not valid",
658+
match="inference_recommendation_id is not valid",
659659
):
660660
model.deploy(
661661
inference_recommendation_id=NOT_EXISTED_MODEL_RECOMMENDATION_ID,

0 commit comments

Comments
 (0)