
Commit 12b19ee

gwang111 and jinpengqi authored
feat: Add support for Deployment Recommendation ID in model.deploy(). No tagging support (#3920)
Co-authored-by: jinpengqi <[email protected]>
Co-authored-by: Gary Wang 😤 <[email protected]>
1 parent 1be4460 commit 12b19ee
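
For orientation, this change lets a recommendation id surfaced by ``DescribeModel`` be passed straight to ``Model.deploy()``. The sketch below outlines the intended flow; it is an illustration only, and the image URI, S3 path, role ARN, and endpoint name are assumed placeholders rather than values from this commit.

# Hedged sketch: deploy a model using a DeploymentRecommendation id from DescribeModel.
import time

import sagemaker
from sagemaker.model import Model

session = sagemaker.Session()
model = Model(
    image_uri="<inference-image-uri>",        # placeholder
    model_data="s3://<bucket>/model.tar.gz",  # placeholder
    role="<execution-role-arn>",              # placeholder
    sagemaker_session=session,
)
model.create(instance_type="ml.c5.xlarge")

# Poll DescribeModel until a DeploymentRecommendation has been computed for the model.
while True:
    response = session.sagemaker_client.describe_model(ModelName=model.name)
    recommendation = response.get("DeploymentRecommendation")
    if recommendation and recommendation.get("RecommendationStatus") == "COMPLETED":
        break
    time.sleep(5)

rec_id = recommendation["RealTimeInferenceRecommendations"][0]["RecommendationId"]

# A DeploymentRecommendation id needs an explicit initial_instance_count; the recommended
# instance type and container environment are applied by deploy() itself.
predictor = model.deploy(
    inference_recommendation_id=rec_id,
    initial_instance_count=1,
    endpoint_name="recommended-endpoint",     # placeholder
)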

File tree

6 files changed, +422 -54 lines changed


src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 109 additions & 26 deletions
@@ -237,7 +237,12 @@ def _update_params(
                 async_inference_config,
                 explainer_config,
             )
-        return inference_recommendation or (instance_type, initial_instance_count)
+
+        return (
+            inference_recommendation
+            if inference_recommendation
+            else (instance_type, initial_instance_count)
+        )
 
     def _update_params_for_right_size(
         self,
@@ -365,12 +370,6 @@ def _update_params_for_recommendation_id(
             return (instance_type, initial_instance_count)
 
         # Validate non-compatible parameters with recommendation id
-        if bool(instance_type) != bool(initial_instance_count):
-            raise ValueError(
-                "Please either do not specify instance_type and initial_instance_count"
-                "since they are in recommendation, or specify both of them if you want"
-                "to override the recommendation."
-            )
         if accelerator_type is not None:
             raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
         if async_inference_config is not None:
@@ -386,30 +385,38 @@ def _update_params_for_recommendation_id(
 
         # Validate recommendation id
         if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
-            raise ValueError("Inference Recommendation id is not valid")
-        recommendation_job_name = inference_recommendation_id.split("/")[0]
+            raise ValueError("inference_recommendation_id is not valid")
+        job_or_model_name = inference_recommendation_id.split("/")[0]
 
         sage_client = self.sagemaker_session.sagemaker_client
-        recommendation_res = sage_client.describe_inference_recommendations_job(
-            JobName=recommendation_job_name
+        # Get recommendation from right size job and model
+        (
+            right_size_recommendation,
+            model_recommendation,
+            right_size_job_res,
+        ) = self._get_recommendation(
+            sage_client=sage_client,
+            job_or_model_name=job_or_model_name,
+            inference_recommendation_id=inference_recommendation_id,
         )
-        input_config = recommendation_res["InputConfig"]
 
-        recommendation = next(
-            (
-                rec
-                for rec in recommendation_res["InferenceRecommendations"]
-                if rec["RecommendationId"] == inference_recommendation_id
-            ),
-            None,
-        )
+        # Update params based on model recommendation
+        if model_recommendation:
+            if initial_instance_count is None:
+                raise ValueError("Must specify model recommendation id and instance count.")
+            self.env.update(model_recommendation["Environment"])
+            instance_type = model_recommendation["InstanceType"]
+            return (instance_type, initial_instance_count)
 
-        if not recommendation:
+        # Update params based on default inference recommendation
+        if bool(instance_type) != bool(initial_instance_count):
             raise ValueError(
-                "inference_recommendation_id does not exist in InferenceRecommendations list"
+                "instance_type and initial_instance_count are mutually exclusive with "
+                "recommendation id since they are in recommendation. "
+                "Please specify both of them if you want to override the recommendation."
             )
-
-        model_config = recommendation["ModelConfiguration"]
+        input_config = right_size_job_res["InputConfig"]
+        model_config = right_size_recommendation["ModelConfiguration"]
         envs = (
             model_config["EnvironmentParameters"]
             if "EnvironmentParameters" in model_config
@@ -458,8 +465,10 @@ def _update_params_for_recommendation_id(
         self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
         self.image_uri = compilation_res["InferenceImage"]
 
-        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
-        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
+        instance_type = right_size_recommendation["EndpointConfiguration"]["InstanceType"]
+        initial_instance_count = right_size_recommendation["EndpointConfiguration"][
+            "InitialInstanceCount"
+        ]
 
         return (instance_type, initial_instance_count)
 
@@ -527,3 +536,77 @@ def _convert_to_stopping_conditions_json(
             threshold.to_json for threshold in model_latency_thresholds
         ]
         return stopping_conditions
+
+    def _get_recommendation(self, sage_client, job_or_model_name, inference_recommendation_id):
+        """Get recommendation from right size job and model"""
+        right_size_recommendation, model_recommendation, right_size_job_res = None, None, None
+        right_size_recommendation, right_size_job_res = self._get_right_size_recommendation(
+            sage_client=sage_client,
+            job_or_model_name=job_or_model_name,
+            inference_recommendation_id=inference_recommendation_id,
+        )
+        if right_size_recommendation is None:
+            model_recommendation = self._get_model_recommendation(
+                sage_client=sage_client,
+                job_or_model_name=job_or_model_name,
+                inference_recommendation_id=inference_recommendation_id,
+            )
+            if model_recommendation is None:
+                raise ValueError("inference_recommendation_id is not valid")
+
+        return right_size_recommendation, model_recommendation, right_size_job_res
+
+    def _get_right_size_recommendation(
+        self,
+        sage_client,
+        job_or_model_name,
+        inference_recommendation_id,
+    ):
+        """Get recommendation from right size job"""
+        right_size_recommendation, right_size_job_res = None, None
+        try:
+            right_size_job_res = sage_client.describe_inference_recommendations_job(
+                JobName=job_or_model_name
+            )
+            if right_size_job_res:
+                right_size_recommendation = self._search_recommendation(
+                    recommendation_list=right_size_job_res["InferenceRecommendations"],
+                    inference_recommendation_id=inference_recommendation_id,
+                )
+        except sage_client.exceptions.ResourceNotFound:
+            pass
+
+        return right_size_recommendation, right_size_job_res
+
+    def _get_model_recommendation(
+        self,
+        sage_client,
+        job_or_model_name,
+        inference_recommendation_id,
+    ):
+        """Get recommendation from model"""
+        model_recommendation = None
+        try:
+            model_res = sage_client.describe_model(ModelName=job_or_model_name)
+            if model_res:
+                model_recommendation = self._search_recommendation(
+                    recommendation_list=model_res["DeploymentRecommendation"][
+                        "RealTimeInferenceRecommendations"
+                    ],
+                    inference_recommendation_id=inference_recommendation_id,
+                )
+        except sage_client.exceptions.ResourceNotFound:
+            pass
+
+        return model_recommendation
+
+    def _search_recommendation(self, recommendation_list, inference_recommendation_id):
+        """Search recommendation based on recommendation id"""
+        return next(
+            (
+                rec
+                for rec in recommendation_list
+                if rec["RecommendationId"] == inference_recommendation_id
+            ),
+            None,
+        )

src/sagemaker/model.py

Lines changed: 2 additions & 0 deletions
@@ -1210,6 +1210,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+                This can also be a recommendation id returned from ``DescribeModel``, contained
+                in a list of ``RealTimeInferenceRecommendations`` within ``DeploymentRecommendation``.
             explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
                 configuration for use with Amazon SageMaker Clarify. Default: None.
         Raises:
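
As a rough illustration of the two id formats this parameter now accepts, assuming ``model`` is an already-created ``sagemaker.model.Model`` (the job name, model name, and 8-character suffixes below are hypothetical, modeled on the test constants in this commit):

# Sketch only: contrast of the two accepted inference_recommendation_id formats.

# 1) Id from an Inference Recommender (right-size) job, "<job-name>/<suffix>":
#    instance type and count are taken from the recommendation's EndpointConfiguration,
#    so neither needs to be passed (specify both only to override the recommendation).
predictor = model.deploy(inference_recommendation_id="my-ir-job/5bcee92e")

# 2) Id from DescribeModel's DeploymentRecommendation, "<model-name>/<suffix>":
#    an explicit initial_instance_count is required for this form.
predictor = model.deploy(
    inference_recommendation_id="my-model-name/v0KObO5d",
    initial_instance_count=1,
)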

tests/integ/test_inference_recommender.py

Lines changed: 77 additions & 0 deletions
@@ -303,6 +303,30 @@ def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_
         sagemaker_session.delete_model(ModelName=model.name)
 
 
+@pytest.fixture(scope="module")
+def created_base_model(sagemaker_session, cpu_instance_type):
+    model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
+    region = sagemaker_session._region_name
+    image_uri = image_uris.retrieve(
+        framework="sklearn", region=region, version="1.0-1", image_scope="inference"
+    )
+
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+
+    model = Model(
+        model_data=model_data,
+        role=role_arn,
+        entry_point=IR_SKLEARN_ENTRY_POINT,
+        image_uri=image_uri,
+        sagemaker_session=sagemaker_session,
+    )
+
+    model.create(instance_type=cpu_instance_type)
+
+    return model
+
+
 @pytest.mark.slow_test
 def test_default_right_size_and_deploy_registered_model_sklearn(
     default_right_sized_model, sagemaker_session
@@ -453,3 +477,56 @@ def test_deploy_inference_recommendation_id_with_registered_model_sklearn(
         )
         predictor.delete_model()
         predictor.delete_endpoint()
+
+
+@pytest.mark.slow_test
+def test_deploy_deployment_recommendation_id_with_model(created_base_model, sagemaker_session):
+    with timeout(minutes=20):
+        try:
+            deployment_recommendation = poll_for_deployment_recommendation(
+                created_base_model, sagemaker_session
+            )
+
+            assert deployment_recommendation is not None
+
+            real_time_recommendations = deployment_recommendation.get(
+                "RealTimeInferenceRecommendations"
+            )
+            recommendation_id = real_time_recommendations[0].get("RecommendationId")
+
+            endpoint_name = unique_name_from_base("test-rec-id-deployment-default-sklearn")
+            created_base_model.predictor_cls = SKLearnPredictor
+            predictor = created_base_model.deploy(
+                inference_recommendation_id=recommendation_id,
+                initial_instance_count=1,
+                endpoint_name=endpoint_name,
+            )
+
+            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
+
+            inference = predictor.predict(payload)
+            assert inference is not None
+            assert 26 == len(inference)
+        finally:
+            predictor.delete_model()
+            predictor.delete_endpoint()
+
+
+def poll_for_deployment_recommendation(created_base_model, sagemaker_session):
+    with timeout(minutes=1):
+        try:
+            completed = False
+            while not completed:
+                describe_model_response = sagemaker_session.sagemaker_client.describe_model(
+                    ModelName=created_base_model.name
+                )
+                deployment_recommendation = describe_model_response.get("DeploymentRecommendation")
+
+                completed = (
+                    deployment_recommendation is not None
+                    and "COMPLETED" == deployment_recommendation.get("RecommendationStatus")
+                )
+            return deployment_recommendation
+        except Exception as e:
+            created_base_model.delete_model()
+            raise e

tests/unit/sagemaker/inference_recommender/constants.py

Lines changed: 45 additions & 0 deletions
@@ -37,7 +37,10 @@
 
 INVALID_RECOMMENDATION_ID = "ir-job6ab0ff22"
 NOT_EXISTED_RECOMMENDATION_ID = IR_JOB_NAME + "/ad3ec9ee"
+NOT_EXISTED_MODEL_RECOMMENDATION_ID = IR_MODEL_NAME + "/ad3ec9ee"
 RECOMMENDATION_ID = IR_JOB_NAME + "/5bcee92e"
+MODEL_RECOMMENDATION_ID = IR_MODEL_NAME + "/v0KObO5d"
+MODEL_RECOMMENDATION_ENV = {"TS_DEFAULT_WORKERS_PER_MODEL": "4"}
 
 IR_CONTAINER_CONFIG = {
     "Domain": "MACHINE_LEARNING",
@@ -95,6 +98,21 @@
         "Image": IR_IMAGE,
         "ModelDataUrl": IR_MODEL_DATA,
     },
+    "DeploymentRecommendation": {
+        "RecommendationStatus": "COMPLETED",
+        "RealTimeInferenceRecommendations": [
+            {
+                "RecommendationId": MODEL_RECOMMENDATION_ID,
+                "InstanceType": "ml.g4dn.2xlarge",
+                "Environment": MODEL_RECOMMENDATION_ENV,
+            },
+            {
+                "RecommendationId": "test-model-name/d248qVYU",
+                "InstanceType": "ml.c6i.large",
+                "Environment": {},
+            },
+        ],
+    },
 }
 
 DESCRIBE_MODEL_PACKAGE_RESPONSE = {
@@ -134,3 +152,30 @@
     "ModelArtifacts": {"S3ModelArtifacts": IR_COMPILATION_MODEL_DATA},
     "InferenceImage": IR_COMPILATION_IMAGE,
 }
+
+IR_CONTAINER_DEF = {
+    "Image": IR_IMAGE,
+    "Environment": IR_ENV,
+    "ModelDataUrl": IR_MODEL_DATA,
+}
+
+DEPLOYMENT_RECOMMENDATION_CONTAINER_DEF = {
+    "Image": IR_IMAGE,
+    "Environment": MODEL_RECOMMENDATION_ENV,
+    "ModelDataUrl": IR_MODEL_DATA,
+}
+
+IR_COMPILATION_CONTAINER_DEF = {
+    "Image": IR_COMPILATION_IMAGE,
+    "Environment": {},
+    "ModelDataUrl": IR_COMPILATION_MODEL_DATA,
+}
+
+IR_MODEL_PACKAGE_CONTAINER_DEF = {
+    "ModelPackageName": IR_MODEL_PACKAGE_VERSION_ARN,
+    "Environment": IR_ENV,
+}
+
+IR_COMPILATION_MODEL_PACKAGE_CONTAINER_DEF = {
+    "ModelPackageName": IR_MODEL_PACKAGE_VERSION_ARN,
+}
