feat: Add support for Deployment Recommendation ID in model.deploy() #3897

Closed · wants to merge 11 commits
149 changes: 121 additions & 28 deletions src/sagemaker/inference_recommender/inference_recommender_mixin.py
@@ -30,6 +30,10 @@

LOGGER = logging.getLogger("sagemaker")

DEPLOYMENT_RECOMMENDATION_TAG = "PythonSDK-DeploymentRecommendation"

RIGHT_SIZE_TAG = "PythonSDK-RightSize"


class Phase:
"""Used to store phases of a traffic pattern to perform endpoint load testing.
@@ -218,6 +222,7 @@ def _update_params(
explainer_config = kwargs["explainer_config"]
inference_recommendation_id = kwargs["inference_recommendation_id"]
inference_recommender_job_results = kwargs["inference_recommender_job_results"]
tags = kwargs["tags"]
if inference_recommendation_id is not None:
inference_recommendation = self._update_params_for_recommendation_id(
instance_type=instance_type,
@@ -237,7 +242,11 @@
async_inference_config,
explainer_config,
)
-        return inference_recommendation or (instance_type, initial_instance_count)
+
+        if inference_recommendation:
+            tags = self._add_client_type_tag(tags, inference_recommendation[2])
+            return (inference_recommendation[0], inference_recommendation[1], tags)
+        return (instance_type, initial_instance_count, tags)
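
For context (a reading aid, not part of the diff): `_update_params` now returns a 3-tuple whose third element is the caller's tag list, extended with a `ClientType` tag whenever a recommendation was applied. A minimal sketch of the contract, assuming `model` mixes in this class and carries right-size job results:

```python
# Sketch only: the kwargs mirror the keys _update_params reads above.
instance_type, initial_instance_count, tags = model._update_params(
    instance_type=None,
    initial_instance_count=None,
    accelerator_type=None,
    async_inference_config=None,
    explainer_config=None,
    inference_recommendation_id=None,
    inference_recommender_job_results=model.inference_recommender_job_results,
    tags=None,
)
# When the right-size path is taken, tags becomes:
# [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]
```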

def _update_params_for_right_size(
self,
@@ -301,7 +310,7 @@ def _update_params_for_right_size(
initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
"InitialInstanceCount"
]
-        return (instance_type, initial_instance_count)
+        return (instance_type, initial_instance_count, RIGHT_SIZE_TAG)

def _update_params_for_recommendation_id(
self,
Expand Down Expand Up @@ -365,12 +374,6 @@ def _update_params_for_recommendation_id(
return (instance_type, initial_instance_count)

# Validate non-compatible parameters with recommendation id
-        if bool(instance_type) != bool(initial_instance_count):
-            raise ValueError(
-                "Please either do not specify instance_type and initial_instance_count"
-                "since they are in recommendation, or specify both of them if you want"
-                "to override the recommendation."
-            )
if accelerator_type is not None:
raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
if async_inference_config is not None:
@@ -386,30 +389,38 @@

# Validate recommendation id
if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
raise ValueError("Inference Recommendation id is not valid")
recommendation_job_name = inference_recommendation_id.split("/")[0]
raise ValueError("inference_recommendation_id is not valid")
job_or_model_name = inference_recommendation_id.split("/")[0]

sage_client = self.sagemaker_session.sagemaker_client
-        recommendation_res = sage_client.describe_inference_recommendations_job(
-            JobName=recommendation_job_name
-        )
+        # Get recommendation from right size job and model
+        (
+            right_size_recommendation,
+            model_recommendation,
+            right_size_job_res,
+        ) = self._get_recommendation(
+            sage_client=sage_client,
+            job_or_model_name=job_or_model_name,
+            inference_recommendation_id=inference_recommendation_id,
+        )
-        input_config = recommendation_res["InputConfig"]
-
-        recommendation = next(
-            (
-                rec
-                for rec in recommendation_res["InferenceRecommendations"]
-                if rec["RecommendationId"] == inference_recommendation_id
-            ),
-            None,
-        )
+        # Update params based on model recommendation
+        if model_recommendation:
+            if initial_instance_count is None:
+                raise ValueError(
+                    "initial_instance_count is required when deploying with "
+                    "a model recommendation id."
+                )
+            self.env.update(model_recommendation["Environment"])
+            instance_type = model_recommendation["InstanceType"]
+            return (instance_type, initial_instance_count, DEPLOYMENT_RECOMMENDATION_TAG)

-        if not recommendation:
-            raise ValueError(
-                "inference_recommendation_id does not exist in InferenceRecommendations list"
-            )
+        # Update params based on default inference recommendation
+        if bool(instance_type) != bool(initial_instance_count):
+            raise ValueError(
+                "instance_type and initial_instance_count are mutually exclusive with "
+                "recommendation id since they are in the recommendation. "
+                "Please specify both of them if you want to override the recommendation."
+            )

-        model_config = recommendation["ModelConfiguration"]
+        input_config = right_size_job_res["InputConfig"]
+        model_config = right_size_recommendation["ModelConfiguration"]
envs = (
model_config["EnvironmentParameters"]
if "EnvironmentParameters" in model_config
@@ -458,10 +469,12 @@ def _update_params_for_recommendation_id(
self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
self.image_uri = compilation_res["InferenceImage"]

-        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
-        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
+        instance_type = right_size_recommendation["EndpointConfiguration"]["InstanceType"]
+        initial_instance_count = right_size_recommendation["EndpointConfiguration"][
+            "InitialInstanceCount"
+        ]

-        return (instance_type, initial_instance_count)
+        return (instance_type, initial_instance_count, RIGHT_SIZE_TAG)
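
For context (hedged example with hypothetical ids, not taken from the diff): both id formats follow the `<name>/<8-character suffix>` pattern validated above; a right-size job id resolves through `DescribeInferenceRecommendationsJob`, while a deployment recommendation id from `DescribeModel` additionally requires `initial_instance_count`:

```python
# Right-size job recommendation id: "<job name>/<8-char id>" (hypothetical values).
predictor = model.deploy(
    inference_recommendation_id="my-right-size-job/a1b2c3d4",
    endpoint_name="rec-endpoint",
)

# Deployment recommendation id from DescribeModel: "<model name>/<8-char id>".
# initial_instance_count is mandatory on this path, per the check above.
predictor = model.deploy(
    inference_recommendation_id="my-model/e5f6a7b8",
    initial_instance_count=1,
    endpoint_name="rec-endpoint-2",
)
```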

def _convert_to_endpoint_configurations_json(
self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
@@ -527,3 +540,83 @@ def _convert_to_stopping_conditions_json(
threshold.to_json for threshold in model_latency_thresholds
]
return stopping_conditions

    def _get_recommendation(self, sage_client, job_or_model_name, inference_recommendation_id):
        """Get recommendation from right size job and model"""
        right_size_recommendation, model_recommendation, right_size_job_res = None, None, None
        right_size_recommendation, right_size_job_res = self._get_right_size_recommendation(
            sage_client=sage_client,
            job_or_model_name=job_or_model_name,
            inference_recommendation_id=inference_recommendation_id,
        )
        if right_size_recommendation is None:
            model_recommendation = self._get_model_recommendation(
                sage_client=sage_client,
                job_or_model_name=job_or_model_name,
                inference_recommendation_id=inference_recommendation_id,
            )
            if model_recommendation is None:
                raise ValueError("inference_recommendation_id is not valid")

        return right_size_recommendation, model_recommendation, right_size_job_res
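
The lookup order matters here: the name before the `/` is first treated as a right-size job name, and only a `ResourceNotFound` from that call falls through to `DescribeModel`. A stubbed client illustrating the fallback (names and values hypothetical; the nested class stands in for the boto3 error type):

```python
class StubSageClient:
    """Minimal stand-in for the SageMaker client used by _get_recommendation."""

    class exceptions:
        class ResourceNotFound(Exception):
            pass

    def describe_inference_recommendations_job(self, JobName):
        # No right-size job by this name, so the model path is consulted next.
        raise self.exceptions.ResourceNotFound()

    def describe_model(self, ModelName):
        return {
            "DeploymentRecommendation": {
                "RealTimeInferenceRecommendations": [
                    {
                        "RecommendationId": "my-model/e5f6a7b8",
                        "InstanceType": "ml.c5.xlarge",
                        "Environment": {},
                    }
                ]
            }
        }
```

With this stub, `_get_recommendation(sage_client=StubSageClient(), job_or_model_name="my-model", inference_recommendation_id="my-model/e5f6a7b8")` would return `(None, <model recommendation>, None)`, routing `_update_params_for_recommendation_id` into the model-recommendation branch above.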

    def _get_right_size_recommendation(
        self,
        sage_client,
        job_or_model_name,
        inference_recommendation_id,
    ):
        """Get recommendation from right size job"""
        right_size_recommendation, right_size_job_res = None, None
        try:
            right_size_job_res = sage_client.describe_inference_recommendations_job(
                JobName=job_or_model_name
            )
            if right_size_job_res:
                right_size_recommendation = self._search_recommendation(
                    recommendation_list=right_size_job_res["InferenceRecommendations"],
                    inference_recommendation_id=inference_recommendation_id,
                )
        except sage_client.exceptions.ResourceNotFound:
            pass

        return right_size_recommendation, right_size_job_res

    def _get_model_recommendation(
        self,
        sage_client,
        job_or_model_name,
        inference_recommendation_id,
    ):
        """Get recommendation from model"""
        model_recommendation = None
        try:
            model_res = sage_client.describe_model(ModelName=job_or_model_name)
            if model_res:
                model_recommendation = self._search_recommendation(
                    recommendation_list=model_res["DeploymentRecommendation"][
                        "RealTimeInferenceRecommendations"
                    ],
                    inference_recommendation_id=inference_recommendation_id,
                )
        except sage_client.exceptions.ResourceNotFound:
            pass

        return model_recommendation

    def _search_recommendation(self, recommendation_list, inference_recommendation_id):
        """Search recommendation based on recommendation id"""
        return next(
            (
                rec
                for rec in recommendation_list
                if rec["RecommendationId"] == inference_recommendation_id
            ),
            None,
        )

    def _add_client_type_tag(self, tags, client_type):
        """Tagging for Inference Recommender and Deployment Recommendations"""
        client_type_tag = {"Key": "ClientType", "Value": client_type}
        if tags:
            tags.append(client_type_tag)
        else:
            tags = [client_type_tag]
        return tags
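
Two small points worth noting (examples hypothetical): `list.append` returns `None`, which is why the helper mutates and then returns the list rather than assigning the result of `append`; and a missing tag list yields a fresh one-element list:

```python
# Existing tags are extended in place:
tags = [{"Key": "Team", "Value": "ml"}]
tags = model._add_client_type_tag(tags, DEPLOYMENT_RECOMMENDATION_TAG)
# [{"Key": "Team", "Value": "ml"},
#  {"Key": "ClientType", "Value": "PythonSDK-DeploymentRecommendation"}]

# With no tags supplied, a new list is returned:
assert model._add_client_type_tag(None, RIGHT_SIZE_TAG) == [
    {"Key": "ClientType", "Value": "PythonSDK-RightSize"}
]
```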
14 changes: 11 additions & 3 deletions src/sagemaker/model.py
@@ -1209,6 +1209,8 @@ def deploy(
inference_recommendation_id (str): The recommendation id which specifies the
recommendation you picked from inference recommendation job results and
would like to deploy the model and endpoint with recommended parameters.
This can also be a recommendation id returned from ``DescribeModel`` contained in
a list of ``RealTimeInferenceRecommendations`` within ``DeploymentRecommendation``.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. Default: None.
Raises:
Expand Down Expand Up @@ -1251,7 +1253,7 @@ def deploy(
inference_recommendation_id is not None
or self.inference_recommender_job_results is not None
):
-            instance_type, initial_instance_count = self._update_params(
+            instance_type, initial_instance_count, tags = self._update_params(
instance_type=instance_type,
initial_instance_count=initial_instance_count,
accelerator_type=accelerator_type,
@@ -1260,6 +1262,7 @@
explainer_config=explainer_config,
inference_recommendation_id=inference_recommendation_id,
inference_recommender_job_results=self.inference_recommender_job_results,
tags=tags,
)

is_async = async_inference_config is not None
@@ -1754,10 +1757,10 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar

Args:
args: Positional arguments coming from the caller. This class does not require
-                any so they are ignored.
+                any but will look for tags in the 3rd parameter.

kwargs: Keyword arguments coming from the caller. This class does not require
-                any so they are ignored.
+                any but will search for tags if not in args.
"""
if self.algorithm_arn:
# When ModelPackage is created using an algorithm_arn we need to first
@@ -1779,12 +1782,17 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
self._ensure_base_name_if_needed(model_package_name.split("/")[-1])
self._set_model_name_if_needed()

# If tags were passed positionally, they must be the 3rd argument
# If not, then check kwargs and default to None
tags = args[2] if len(args) >= 3 else kwargs.get("tags")

self.sagemaker_session.create_model(
self.name,
self.role,
container_def,
vpc_config=self.vpc_config,
enable_network_isolation=self.enable_network_isolation(),
tags=tags,
)

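
A sketch of the tag plumbing this enables (hypothetical values; `deploy()` makes the positional call internally): tags arrive either as the 3rd positional argument or as a keyword, and both shapes reach `create_model` the same way:

```python
client_tags = [{"Key": "ClientType", "Value": "PythonSDK-RightSize"}]

# Positional: (instance_type, accelerator_type, tags) — tags is args[2].
model_package._create_sagemaker_model("ml.c5.xlarge", None, client_tags)

# Keyword: picked up via kwargs.get("tags").
model_package._create_sagemaker_model(tags=client_tags)
```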
def _ensure_base_name_if_needed(self, base_name):
77 changes: 77 additions & 0 deletions tests/integ/test_inference_recommender.py
@@ -303,6 +303,30 @@ def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_
sagemaker_session.delete_model(ModelName=model.name)


@pytest.fixture(scope="module")
def created_base_model(sagemaker_session, cpu_instance_type):
model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
region = sagemaker_session._region_name
image_uri = image_uris.retrieve(
framework="sklearn", region=region, version="1.0-1", image_scope="inference"
)

iam_client = sagemaker_session.boto_session.client("iam")
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

model = Model(
model_data=model_data,
role=role_arn,
entry_point=IR_SKLEARN_ENTRY_POINT,
image_uri=image_uri,
sagemaker_session=sagemaker_session,
)

model.create(instance_type=cpu_instance_type)

return model


@pytest.mark.slow_test
def test_default_right_size_and_deploy_registered_model_sklearn(
default_right_sized_model, sagemaker_session
Expand Down Expand Up @@ -453,3 +477,56 @@ def test_deploy_inference_recommendation_id_with_registered_model_sklearn(
)
predictor.delete_model()
predictor.delete_endpoint()


@pytest.mark.slow_test
def test_deploy_deployment_recommendation_id_with_model(created_base_model, sagemaker_session):
    with timeout(minutes=20):
        try:
            deployment_recommendation = poll_for_deployment_recommendation(
                created_base_model, sagemaker_session
            )

            assert deployment_recommendation is not None

            real_time_recommendations = deployment_recommendation.get(
                "RealTimeInferenceRecommendations"
            )
            recommendation_id = real_time_recommendations[0].get("RecommendationId")

            endpoint_name = unique_name_from_base("test-rec-id-deployment-default-sklearn")
            created_base_model.predictor_cls = SKLearnPredictor
            predictor = created_base_model.deploy(
                inference_recommendation_id=recommendation_id,
                initial_instance_count=1,
                endpoint_name=endpoint_name,
            )

            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)

            inference = predictor.predict(payload)
            assert inference is not None
            assert 26 == len(inference)
        finally:
            predictor.delete_model()
            predictor.delete_endpoint()


def poll_for_deployment_recommendation(created_base_model, sagemaker_session):
    with timeout(minutes=1):
        try:
            completed = False
            while not completed:
                describe_model_response = sagemaker_session.sagemaker_client.describe_model(
                    ModelName=created_base_model.name
                )
                deployment_recommendation = describe_model_response.get("DeploymentRecommendation")

                completed = (
                    deployment_recommendation is not None
                    and "COMPLETED" == deployment_recommendation.get("RecommendationStatus")
                )
            return deployment_recommendation
        except Exception as e:
            created_base_model.delete_model()
            raise e