Skip to content

Commit dbd757e

Browse files
jinpengqigwang111
authored andcommitted
Add deployment support for instant recommendations
1 parent 624cac8 commit dbd757e

File tree

3 files changed

+128
-31
lines changed

3 files changed

+128
-31
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,6 @@ def _update_params_for_recommendation_id(
365365
return (instance_type, initial_instance_count)
366366

367367
# Validate non-compatible parameters with recommendation id
368-
if bool(instance_type) != bool(initial_instance_count):
369-
raise ValueError(
370-
"Please either do not specify instance_type and initial_instance_count"
371-
"since they are in recommendation, or specify both of them if you want"
372-
"to override the recommendation."
373-
)
374368
if accelerator_type is not None:
375369
raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
376370
if async_inference_config is not None:
@@ -390,26 +384,66 @@ def _update_params_for_recommendation_id(
390384
recommendation_job_name = inference_recommendation_id.split("/")[0]
391385

392386
sage_client = self.sagemaker_session.sagemaker_client
393-
recommendation_res = sage_client.describe_inference_recommendations_job(
394-
JobName=recommendation_job_name
395-
)
396-
input_config = recommendation_res["InputConfig"]
397387

398-
recommendation = next(
399-
(
400-
rec
401-
for rec in recommendation_res["InferenceRecommendations"]
402-
if rec["RecommendationId"] == inference_recommendation_id
403-
),
404-
None,
405-
)
388+
# Retrieve model or inference recommendation job details
389+
recommendation_res, model_res = None, None
390+
try:
391+
recommendation_res = sage_client.describe_inference_recommendations_job(
392+
JobName=recommendation_job_name
393+
)
394+
except sage_client.exceptions.ResourceNotFound:
395+
pass
396+
try:
397+
model_res = sage_client.describe_model(ModelName=recommendation_job_name)
398+
except sage_client.exceptions.ResourceNotFound:
399+
pass
400+
if recommendation_res is None and model_res is None:
401+
raise ValueError("Inference Recommendation id is not valid")
406402

407-
if not recommendation:
408-
raise ValueError(
409-
"inference_recommendation_id does not exist in InferenceRecommendations list"
403+
# Search the recommendation from above describe result lists
404+
inference_recommendation, instant_recommendation = None, None
405+
if recommendation_res:
406+
inference_recommendation = next(
407+
(
408+
rec
409+
for rec in recommendation_res["InferenceRecommendations"]
410+
if rec["RecommendationId"] == inference_recommendation_id
411+
),
412+
None,
410413
)
414+
if model_res:
415+
instant_recommendation = next(
416+
(
417+
rec
418+
for rec in model_res["DeploymentRecommendation"][
419+
"RealTimeInferenceRecommendations"
420+
]
421+
if rec["RecommendationId"] == inference_recommendation_id
422+
),
423+
None,
424+
)
425+
if inference_recommendation is None and instant_recommendation is None:
426+
raise ValueError("Inference Recommendation id does not exist")
427+
428+
# Update params beased on instant recommendation
429+
if instant_recommendation:
430+
if initial_instance_count is None:
431+
raise ValueError(
432+
"Please specify initial_instance_count with instant recommendation id"
433+
)
434+
self.env.update(instant_recommendation["Environment"])
435+
instance_type = instant_recommendation["InstanceType"]
436+
return (instance_type, initial_instance_count)
411437

412-
model_config = recommendation["ModelConfiguration"]
438+
# Update params based on default inference recommendation
439+
if bool(instance_type) != bool(initial_instance_count):
440+
raise ValueError(
441+
"Please either do not specify instance_type and initial_instance_count"
442+
"since they are in recommendation, or specify both of them if you want"
443+
"to override the recommendation."
444+
)
445+
input_config = recommendation_res["InputConfig"]
446+
model_config = inference_recommendation["ModelConfiguration"]
413447
envs = (
414448
model_config["EnvironmentParameters"]
415449
if "EnvironmentParameters" in model_config
@@ -458,8 +492,10 @@ def _update_params_for_recommendation_id(
458492
self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
459493
self.image_uri = compilation_res["InferenceImage"]
460494

461-
instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
462-
initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
495+
instance_type = inference_recommendation["EndpointConfiguration"]["InstanceType"]
496+
initial_instance_count = inference_recommendation["EndpointConfiguration"][
497+
"InitialInstanceCount"
498+
]
463499

464500
return (instance_type, initial_instance_count)
465501

tests/unit/sagemaker/inference_recommender/constants.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737

3838
INVALID_RECOMMENDATION_ID = "ir-job6ab0ff22"
3939
NOT_EXISTED_RECOMMENDATION_ID = IR_JOB_NAME + "/ad3ec9ee"
40+
NOT_EXISTED_INSTANT_RECOMMENDATION_ID = IR_MODEL_NAME + "/ad3ec9ee"
4041
RECOMMENDATION_ID = IR_JOB_NAME + "/5bcee92e"
42+
INSTANT_RECOMMENDATION_ID = IR_MODEL_NAME + "/v0KObO5d"
43+
INSTANT_RECOMMENDATION_ENV = {"TS_DEFAULT_WORKERS_PER_MODEL": "4"}
4144

4245
IR_CONTAINER_CONFIG = {
4346
"Domain": "MACHINE_LEARNING",
@@ -95,6 +98,21 @@
9598
"Image": IR_IMAGE,
9699
"ModelDataUrl": IR_MODEL_DATA,
97100
},
101+
"DeploymentRecommendation": {
102+
"RecommendationStatus": "COMPLETED",
103+
"RealTimeInferenceRecommendations": [
104+
{
105+
"RecommendationId": INSTANT_RECOMMENDATION_ID,
106+
"InstanceType": "ml.g4dn.2xlarge",
107+
"Environment": INSTANT_RECOMMENDATION_ENV,
108+
},
109+
{
110+
"RecommendationId": "test-model-name/d248qVYU",
111+
"InstanceType": "ml.c6i.large",
112+
"Environment": {},
113+
},
114+
],
115+
},
98116
}
99117

100118
DESCRIBE_MODEL_PACKAGE_RESPONSE = {

tests/unit/sagemaker/model/test_deploy.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
DESCRIBE_COMPILATION_JOB_RESPONSE,
2727
DESCRIBE_MODEL_PACKAGE_RESPONSE,
2828
DESCRIBE_MODEL_RESPONSE,
29+
INSTANT_RECOMMENDATION_ENV,
30+
INSTANT_RECOMMENDATION_ID,
2931
INVALID_RECOMMENDATION_ID,
3032
IR_COMPILATION_JOB_NAME,
3133
IR_ENV,
@@ -35,6 +37,7 @@
3537
IR_MODEL_PACKAGE_VERSION_ARN,
3638
IR_COMPILATION_IMAGE,
3739
IR_COMPILATION_MODEL_DATA,
40+
NOT_EXISTED_INSTANT_RECOMMENDATION_ID,
3841
RECOMMENDATION_ID,
3942
NOT_EXISTED_RECOMMENDATION_ID,
4043
)
@@ -543,6 +546,11 @@ def test_deploy_wrong_async_inferenc_config(sagemaker_session):
543546

544547

545548
def test_deploy_ir_with_incompatible_parameters(sagemaker_session):
549+
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = (
550+
create_inference_recommendations_job_default_with_model_package_arn()
551+
)
552+
sagemaker_session.sagemaker_client.describe_model.return_value = None
553+
546554
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
547555

548556
with pytest.raises(
@@ -553,7 +561,7 @@ def test_deploy_ir_with_incompatible_parameters(sagemaker_session):
553561
):
554562
model.deploy(
555563
instance_type=INSTANCE_TYPE,
556-
inference_recommendation_id=INFERENCE_RECOMMENDATION_ID,
564+
inference_recommendation_id=RECOMMENDATION_ID,
557565
)
558566

559567
with pytest.raises(
@@ -564,7 +572,7 @@ def test_deploy_ir_with_incompatible_parameters(sagemaker_session):
564572
):
565573
model.deploy(
566574
initial_instance_count=INSTANCE_COUNT,
567-
inference_recommendation_id=INFERENCE_RECOMMENDATION_ID,
575+
inference_recommendation_id=RECOMMENDATION_ID,
568576
)
569577

570578
with pytest.raises(
@@ -615,6 +623,7 @@ def test_deploy_with_recommendation_id_with_model_pkg_arn(sagemaker_session):
615623
sagemaker_session.sagemaker_client.describe_model_package.side_effect = (
616624
mock_describe_model_package
617625
)
626+
sagemaker_session.sagemaker_client.describe_model.return_value = None
618627

619628
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
620629

@@ -627,11 +636,12 @@ def test_deploy_with_recommendation_id_with_model_pkg_arn(sagemaker_session):
627636
assert model.env == IR_ENV
628637

629638

630-
def test_deploy_with_recommendation_id_with_model_name(sagemaker_session):
631-
def mock_describe_model(ModelName):
632-
if ModelName == IR_MODEL_NAME:
633-
return DESCRIBE_MODEL_RESPONSE
639+
def mock_describe_model(ModelName):
640+
if ModelName == IR_MODEL_NAME:
641+
return DESCRIBE_MODEL_RESPONSE
642+
634643

644+
def test_deploy_with_recommendation_id_with_model_name(sagemaker_session):
635645
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = (
636646
create_inference_recommendations_job_default_with_model_name()
637647
)
@@ -655,6 +665,7 @@ def test_deploy_with_recommendation_id_with_model_pkg_arn_and_compilation(sagema
655665
sagemaker_session.sagemaker_client.describe_model_package.side_effect = (
656666
mock_describe_model_package
657667
)
668+
sagemaker_session.sagemaker_client.describe_model.return_value = None
658669

659670
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
660671

@@ -677,6 +688,7 @@ def mock_describe_compilation_job(CompilationJobName):
677688
sagemaker_session.sagemaker_client.describe_compilation_job.side_effect = (
678689
mock_describe_compilation_job
679690
)
691+
sagemaker_session.sagemaker_client.describe_model.side_effect = mock_describe_model
680692

681693
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
682694

@@ -695,18 +707,49 @@ def test_deploy_with_not_existed_recommendation_id(sagemaker_session):
695707
sagemaker_session.sagemaker_client.describe_compilation_job.return_value = (
696708
DESCRIBE_COMPILATION_JOB_RESPONSE
697709
)
710+
sagemaker_session.sagemaker_client.describe_model.return_value = None
698711

699712
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
700713

701714
with pytest.raises(
702715
ValueError,
703-
match="inference_recommendation_id does not exist in InferenceRecommendations list",
716+
match="Inference Recommendation id does not exist",
704717
):
705718
model.deploy(
706719
inference_recommendation_id=NOT_EXISTED_RECOMMENDATION_ID,
707720
)
708721

709722

723+
def test_deploy_with_invalid_instant_recommendation_id(sagemaker_session):
724+
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = None
725+
sagemaker_session.sagemaker_client.describe_model.side_effect = mock_describe_model
726+
727+
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
728+
729+
with pytest.raises(
730+
ValueError,
731+
match="Inference Recommendation id does not exist",
732+
):
733+
model.deploy(
734+
inference_recommendation_id=NOT_EXISTED_INSTANT_RECOMMENDATION_ID,
735+
)
736+
737+
738+
def test_deploy_with_valid_instant_recommendation_id(sagemaker_session):
739+
sagemaker_session.sagemaker_client.describe_inference_recommendations_job.return_value = None
740+
sagemaker_session.sagemaker_client.describe_model.side_effect = mock_describe_model
741+
742+
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE)
743+
model.deploy(
744+
inference_recommendation_id=INSTANT_RECOMMENDATION_ID,
745+
initial_instance_count=INSTANCE_COUNT,
746+
)
747+
748+
assert model.model_data == MODEL_DATA
749+
assert model.image_uri == MODEL_IMAGE
750+
assert model.env == INSTANT_RECOMMENDATION_ENV
751+
752+
710753
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
711754
@patch("sagemaker.predictor.Predictor._get_endpoint_config_name", Mock())
712755
@patch("sagemaker.predictor.Predictor._get_model_names", Mock())

0 commit comments

Comments
 (0)