Skip to content

Commit a7cc844

Browse files
jinpengqigwang111
authored and committed
Refactor get_recommendation
1 parent 57fe764 commit a7cc844

File tree

3 files changed

+93
-63
lines changed

3 files changed

+93
-63
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -381,58 +381,34 @@ def _update_params_for_recommendation_id(
381381
# Validate recommendation id
382382
if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
383383
raise ValueError("Inference Recommendation id is not valid")
384-
recommendation_job_name = inference_recommendation_id.split("/")[0]
384+
job_or_model_name = inference_recommendation_id.split("/")[0]
385385

386386
sage_client = self.sagemaker_session.sagemaker_client
387387

388-
# Retrieve model or inference recommendation job details
389-
recommendation_res, model_res = None, None
390-
try:
391-
recommendation_res = sage_client.describe_inference_recommendations_job(
392-
JobName=recommendation_job_name
393-
)
394-
except sage_client.exceptions.ResourceNotFound:
395-
pass
396-
try:
397-
model_res = sage_client.describe_model(ModelName=recommendation_job_name)
398-
except sage_client.exceptions.ResourceNotFound:
399-
pass
400-
if recommendation_res is None and model_res is None:
401-
raise ValueError("Inference Recommendation id is not valid")
388+
# Describe inference recommendation job and model details
389+
recommendation_res, model_res = self._describe_recommendation_job_and_model(
390+
sage_client=sage_client,
391+
job_or_model_name=job_or_model_name,
392+
)
402393

403-
# Search the recommendation from above describe result lists
404-
inference_recommendation, instant_recommendation = None, None
405-
if recommendation_res:
406-
inference_recommendation = next(
407-
(
408-
rec
409-
for rec in recommendation_res["InferenceRecommendations"]
410-
if rec["RecommendationId"] == inference_recommendation_id
411-
),
412-
None,
413-
)
414-
if model_res:
415-
instant_recommendation = next(
416-
(
417-
rec
418-
for rec in model_res["DeploymentRecommendation"][
419-
"RealTimeInferenceRecommendations"
420-
]
421-
if rec["RecommendationId"] == inference_recommendation_id
422-
),
423-
None,
424-
)
425-
if inference_recommendation is None and instant_recommendation is None:
426-
raise ValueError("Inference Recommendation id does not exist")
394+
# Search the recommendation from above describe results
395+
(
396+
right_size_recommendation,
397+
model_recommendation,
398+
) = self._get_right_size_and_model_recommendation(
399+
recommendation_res=recommendation_res,
400+
model_res=model_res,
401+
inference_recommendation_id=inference_recommendation_id,
402+
)
427403

428-
# Update params beased on instant recommendation
429-
if instant_recommendation:
404+
# Update params based on model recommendation
405+
if model_recommendation:
430406
if initial_instance_count is None:
431407
raise ValueError(
432-
"Please specify initial_instance_count with instant recommendation id"
408+
"Please specify initial_instance_count with model recommendation id"
433409
)
434-
self.env.update(instant_recommendation["Environment"])
435-
instance_type = instant_recommendation["InstanceType"]
410+
self.env.update(model_recommendation["Environment"])
411+
instance_type = model_recommendation["InstanceType"]
436412
return (instance_type, initial_instance_count)
437413

438414
# Update params based on default inference recommendation
@@ -443,7 +419,7 @@ def _update_params_for_recommendation_id(
443419
"to override the recommendation."
444420
)
445421
input_config = recommendation_res["InputConfig"]
446-
model_config = inference_recommendation["ModelConfiguration"]
422+
model_config = right_size_recommendation["ModelConfiguration"]
447423
envs = (
448424
model_config["EnvironmentParameters"]
449425
if "EnvironmentParameters" in model_config
@@ -492,8 +468,8 @@ def _update_params_for_recommendation_id(
492468
self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
493469
self.image_uri = compilation_res["InferenceImage"]
494470

495-
instance_type = inference_recommendation["EndpointConfiguration"]["InstanceType"]
496-
initial_instance_count = inference_recommendation["EndpointConfiguration"][
471+
instance_type = right_size_recommendation["EndpointConfiguration"]["InstanceType"]
472+
initial_instance_count = right_size_recommendation["EndpointConfiguration"][
497473
"InitialInstanceCount"
498474
]
499475

@@ -563,3 +539,57 @@ def _convert_to_stopping_conditions_json(
563539
threshold.to_json for threshold in model_latency_thresholds
564540
]
565541
return stopping_conditions
542+
543+
def _get_right_size_and_model_recommendation(
544+
self,
545+
model_res=None,
546+
recommendation_res=None,
547+
inference_recommendation_id=None,
548+
):
549+
"""Get recommendation from right size job or model"""
550+
right_size_recommendation, model_recommendation = None, None
551+
if recommendation_res:
552+
right_size_recommendation = self._get_recommendation(
553+
recommendation_list=recommendation_res["InferenceRecommendations"],
554+
inference_recommendation_id=inference_recommendation_id,
555+
)
556+
if model_res:
557+
model_recommendation = self._get_recommendation(
558+
recommendation_list=model_res["DeploymentRecommendation"][
559+
"RealTimeInferenceRecommendations"
560+
],
561+
inference_recommendation_id=inference_recommendation_id,
562+
)
563+
if right_size_recommendation is None and model_recommendation is None:
564+
raise ValueError("Inference Recommendation id is not valid")
565+
566+
return right_size_recommendation, model_recommendation
567+
568+
def _get_recommendation(self, recommendation_list, inference_recommendation_id):
569+
"""Get recommendation based on recommendation id"""
570+
return next(
571+
(
572+
rec
573+
for rec in recommendation_list
574+
if rec["RecommendationId"] == inference_recommendation_id
575+
),
576+
None,
577+
)
578+
579+
def _describe_recommendation_job_and_model(self, sage_client, job_or_model_name):
580+
"""Describe inference recommendation job and model results"""
581+
recommendation_res, model_res = None, None
582+
try:
583+
recommendation_res = sage_client.describe_inference_recommendations_job(
584+
JobName=job_or_model_name
585+
)
586+
except sage_client.exceptions.ResourceNotFound:
587+
pass
588+
try:
589+
model_res = sage_client.describe_model(ModelName=job_or_model_name)
590+
except sage_client.exceptions.ResourceNotFound:
591+
pass
592+
if recommendation_res is None and model_res is None:
593+
raise ValueError("Inference Recommendation id is not valid")
594+
595+
return recommendation_res, model_res

tests/unit/sagemaker/inference_recommender/constants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@
3737

3838
INVALID_RECOMMENDATION_ID = "ir-job6ab0ff22"
3939
NOT_EXISTED_RECOMMENDATION_ID = IR_JOB_NAME + "/ad3ec9ee"
40-
NOT_EXISTED_INSTANT_RECOMMENDATION_ID = IR_MODEL_NAME + "/ad3ec9ee"
40+
NOT_EXISTED_MODEL_RECOMMENDATION_ID = IR_MODEL_NAME + "/ad3ec9ee"
4141
RECOMMENDATION_ID = IR_JOB_NAME + "/5bcee92e"
42-
INSTANT_RECOMMENDATION_ID = IR_MODEL_NAME + "/v0KObO5d"
43-
INSTANT_RECOMMENDATION_ENV = {"TS_DEFAULT_WORKERS_PER_MODEL": "4"}
42+
MODEL_RECOMMENDATION_ID = IR_MODEL_NAME + "/v0KObO5d"
43+
MODEL_RECOMMENDATION_ENV = {"TS_DEFAULT_WORKERS_PER_MODEL": "4"}
4444

4545
IR_CONTAINER_CONFIG = {
4646
"Domain": "MACHINE_LEARNING",
@@ -102,9 +102,9 @@
102102
"RecommendationStatus": "COMPLETED",
103103
"RealTimeInferenceRecommendations": [
104104
{
105-
"RecommendationId": INSTANT_RECOMMENDATION_ID,
105+
"RecommendationId": MODEL_RECOMMENDATION_ID,
106106
"InstanceType": "ml.g4dn.2xlarge",
107-
"Environment": INSTANT_RECOMMENDATION_ENV,
107+
"Environment": MODEL_RECOMMENDATION_ENV,
108108
},
109109
{
110110
"RecommendationId": "test-model-name/d248qVYU",

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
DESCRIBE_COMPILATION_JOB_RESPONSE,
2727
DESCRIBE_MODEL_PACKAGE_RESPONSE,
2828
DESCRIBE_MODEL_RESPONSE,
29-
INSTANT_RECOMMENDATION_ENV,
30-
INSTANT_RECOMMENDATION_ID,
29+
MODEL_RECOMMENDATION_ENV,
30+
MODEL_RECOMMENDATION_ID,
3131
INVALID_RECOMMENDATION_ID,
3232
IR_COMPILATION_JOB_NAME,
3333
IR_ENV,
@@ -37,7 +37,7 @@
3737
IR_MODEL_PACKAGE_VERSION_ARN,
3838
IR_COMPILATION_IMAGE,
3939
IR_COMPILATION_MODEL_DATA,
40-
NOT_EXISTED_INSTANT_RECOMMENDATION_ID,
40+
NOT_EXISTED_MODEL_RECOMMENDATION_ID,
4141
RECOMMENDATION_ID,
4242
NOT_EXISTED_RECOMMENDATION_ID,
4343
)
@@ -700,7 +700,7 @@ def mock_describe_compilation_job(CompilationJobName):
700700
assert model.image_uri == IR_COMPILATION_IMAGE
701701

702702

703-
def test_deploy_with_not_existed_recommendation_id(sagemaker_session):
703+
def test_deploy_with_invalid_inference_recommendation_id(sagemaker_session):
704704
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = (
705705
create_inference_recommendations_job_default_with_model_name_and_compilation()
706706
)
@@ -713,41 +713,41 @@ def test_deploy_with_not_existed_recommendation_id(sagemaker_session):
713713

714714
with pytest.raises(
715715
ValueError,
716-
match="Inference Recommendation id does not exist",
716+
match="Inference Recommendation id is not valid",
717717
):
718718
model.deploy(
719719
inference_recommendation_id=NOT_EXISTED_RECOMMENDATION_ID,
720720
)
721721

722722

723-
def test_deploy_with_invalid_instant_recommendation_id(sagemaker_session):
723+
def test_deploy_with_invalid_model_recommendation_id(sagemaker_session):
724724
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = None
725725
sagemaker_session.sagemaker_client.describe_model.side_effect = mock_describe_model
726726

727727
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
728728

729729
with pytest.raises(
730730
ValueError,
731-
match="Inference Recommendation id does not exist",
731+
match="Inference Recommendation id is not valid",
732732
):
733733
model.deploy(
734-
inference_recommendation_id=NOT_EXISTED_INSTANT_RECOMMENDATION_ID,
734+
inference_recommendation_id=NOT_EXISTED_MODEL_RECOMMENDATION_ID,
735735
)
736736

737737

738-
def test_deploy_with_valid_instant_recommendation_id(sagemaker_session):
738+
def test_deploy_with_valid_model_recommendation_id(sagemaker_session):
739739
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = None
740740
sagemaker_session.sagemaker_client.describe_model.side_effect = mock_describe_model
741741

742742
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
743743
model.deploy(
744-
inference_recommendation_id=INSTANT_RECOMMENDATION_ID,
744+
inference_recommendation_id=MODEL_RECOMMENDATION_ID,
745745
initial_instance_count=INSTANCE_COUNT,
746746
)
747747

748748
assert model.model_data == MODEL_DATA
749749
assert model.image_uri == MODEL_IMAGE
750-
assert model.env == INSTANT_RECOMMENDATION_ENV
750+
assert model.env == MODEL_RECOMMENDATION_ENV
751751

752752

753753
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())

0 commit comments

Comments
 (0)