Skip to content

Commit f8fc9b1

Browse files
jinpengqigwang111
authored andcommitted
Modify to check recommendations job res first.
1 parent 4d2c1ba commit f8fc9b1

File tree

2 files changed

+69
-55
lines changed

2 files changed

+69
-55
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 66 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -380,24 +380,18 @@ def _update_params_for_recommendation_id(
380380

381381
# Validate recommendation id
382382
if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
383-
raise ValueError("Inference Recommendation id is not valid")
383+
raise ValueError("inference_recommendation_id is not valid")
384384
job_or_model_name = inference_recommendation_id.split("/")[0]
385385

386386
sage_client = self.sagemaker_session.sagemaker_client
387-
388-
# Desribe inference recommendation job and model details
389-
recommendation_res, model_res = self._describe_recommendation_job_and_model(
390-
sage_client=sage_client,
391-
job_or_model_name=job_or_model_name,
392-
)
393-
394-
# Search the recommendation from above describe results
387+
# Get recommendation from right size job and model
395388
(
396389
right_size_recommendation,
397390
model_recommendation,
398-
) = self._get_right_size_and_model_recommendation(
399-
recommendation_res=recommendation_res,
400-
model_res=model_res,
391+
right_size_job_res,
392+
) = self._get_recommendation(
393+
sage_client=sage_client,
394+
job_or_model_name=job_or_model_name,
401395
inference_recommendation_id=inference_recommendation_id,
402396
)
403397

@@ -418,7 +412,7 @@ def _update_params_for_recommendation_id(
418412
"since they are in recommendation, or specify both of them if you want"
419413
"to override the recommendation."
420414
)
421-
input_config = recommendation_res["InputConfig"]
415+
input_config = right_size_job_res["InputConfig"]
422416
model_config = right_size_recommendation["ModelConfiguration"]
423417
envs = (
424418
model_config["EnvironmentParameters"]
@@ -540,56 +534,76 @@ def _convert_to_stopping_conditions_json(
540534
]
541535
return stopping_conditions
542536

543-
def _get_right_size_and_model_recommendation(
544-
self,
545-
model_res=None,
546-
recommendation_res=None,
547-
inference_recommendation_id=None,
548-
):
549-
"""Get recommendation from right size job or model"""
550-
right_size_recommendation, model_recommendation = None, None
551-
if recommendation_res:
552-
right_size_recommendation = self._get_recommendation(
553-
recommendation_list=recommendation_res["InferenceRecommendations"],
554-
inference_recommendation_id=inference_recommendation_id,
555-
)
556-
if model_res:
557-
model_recommendation = self._get_recommendation(
558-
recommendation_list=model_res["DeploymentRecommendation"][
559-
"RealTimeInferenceRecommendations"
560-
],
537+
def _get_recommendation(self, sage_client, job_or_model_name, inference_recommendation_id):
538+
"""Get recommendation from right size job and model"""
539+
right_size_recommendation, model_recommendation, right_size_job_res = None, None, None
540+
right_size_recommendation, right_size_job_res = self._get_right_size_recommendation(
541+
sage_client=sage_client,
542+
job_or_model_name=job_or_model_name,
543+
inference_recommendation_id=inference_recommendation_id,
544+
)
545+
if right_size_recommendation is None:
546+
model_recommendation = self._get_model_recommendation(
547+
sage_client=sage_client,
548+
job_or_model_name=job_or_model_name,
561549
inference_recommendation_id=inference_recommendation_id,
562550
)
563-
if right_size_recommendation is None and model_recommendation is None:
564-
raise ValueError("Inference Recommendation id is not valid")
551+
if model_recommendation is None:
552+
raise ValueError("inference_recommendation_id is not valid")
565553

566-
return right_size_recommendation, model_recommendation
554+
return right_size_recommendation, model_recommendation, right_size_job_res
567555

568-
def _get_recommendation(self, recommendation_list, inference_recommendation_id):
569-
"""Get recommendation based on recommendation id"""
570-
return next(
571-
(
572-
rec
573-
for rec in recommendation_list
574-
if rec["RecommendationId"] == inference_recommendation_id
575-
),
576-
None,
577-
)
578-
579-
def _describe_recommendation_job_and_model(self, sage_client, job_or_model_name):
580-
"""Describe inference recommendation job and model results"""
581-
recommendation_res, model_res = None, None
556+
def _get_right_size_recommendation(
557+
self,
558+
sage_client,
559+
job_or_model_name,
560+
inference_recommendation_id,
561+
):
562+
"""Get recommendation from right size job"""
563+
right_size_recommendation, right_size_job_res = None, None
582564
try:
583-
recommendation_res = sage_client.describe_inference_recommendations_job(
565+
right_size_job_res = sage_client.describe_inference_recommendations_job(
584566
JobName=job_or_model_name
585567
)
568+
if right_size_job_res:
569+
right_size_recommendation = self._search_recommendation(
570+
recommendation_list=right_size_job_res["InferenceRecommendations"],
571+
inference_recommendation_id=inference_recommendation_id,
572+
)
586573
except sage_client.exceptions.ResourceNotFound:
587574
pass
575+
576+
return right_size_recommendation, right_size_job_res
577+
578+
def _get_model_recommendation(
579+
self,
580+
sage_client,
581+
job_or_model_name,
582+
inference_recommendation_id,
583+
):
584+
"""Get recommendation from model"""
585+
model_recommendation = None
588586
try:
589587
model_res = sage_client.describe_model(ModelName=job_or_model_name)
588+
if model_res:
589+
model_recommendation = self._search_recommendation(
590+
recommendation_list=model_res["DeploymentRecommendation"][
591+
"RealTimeInferenceRecommendations"
592+
],
593+
inference_recommendation_id=inference_recommendation_id,
594+
)
590595
except sage_client.exceptions.ResourceNotFound:
591596
pass
592-
if recommendation_res is None and model_res is None:
593-
raise ValueError("Inference Recommendation id is not valid")
594597

595-
return recommendation_res, model_res
598+
return model_recommendation
599+
600+
def _search_recommendation(self, recommendation_list, inference_recommendation_id):
601+
"""Search recommendation based on recommendation id"""
602+
return next(
603+
(
604+
rec
605+
for rec in recommendation_list
606+
if rec["RecommendationId"] == inference_recommendation_id
607+
),
608+
None,
609+
)

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def test_deploy_ir_with_incompatible_parameters(sagemaker_session):
605605
def test_deploy_with_wrong_recommendation_id(sagemaker_session):
606606
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
607607

608-
with pytest.raises(ValueError, match="Inference Recommendation id is not valid"):
608+
with pytest.raises(ValueError, match="inference_recommendation_id is not valid"):
609609
model.deploy(
610610
inference_recommendation_id=INVALID_RECOMMENDATION_ID,
611611
)
@@ -713,7 +713,7 @@ def test_deploy_with_invalid_inference_recommendation_id(sagemaker_session):
713713

714714
with pytest.raises(
715715
ValueError,
716-
match="Inference Recommendation id is not valid",
716+
match="inference_recommendation_id is not valid",
717717
):
718718
model.deploy(
719719
inference_recommendation_id=NOT_EXISTED_RECOMMENDATION_ID,
@@ -728,7 +728,7 @@ def test_deploy_with_invalid_model_recommendation_id(sagemaker_session):
728728

729729
with pytest.raises(
730730
ValueError,
731-
match="Inference Recommendation id is not valid",
731+
match="inference_recommendation_id is not valid",
732732
):
733733
model.deploy(
734734
inference_recommendation_id=NOT_EXISTED_MODEL_RECOMMENDATION_ID,

0 commit comments

Comments
 (0)