
Commit 8d2c16b

feature: Inference recommendation id deployment support (#3631)

Authored by jinpengqi (Jinpeng Qi)
Co-authored-by: Jinpeng Qi <[email protected]>
1 parent: ba30a1f
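
This commit threads a new `inference_recommendation_id` argument through the public `deploy()` entry points (Model, Estimator, HuggingFaceModel, TensorFlowModel). A minimal usage sketch follows, assuming a completed Inference Recommender job; the job name, recommendation suffix, and resource identifiers below are hypothetical placeholders, not values taken from this commit:

    from sagemaker.model import Model

    # Hypothetical model; image_uri, model_data, and role are placeholders.
    model = Model(
        image_uri="<ecr-image-uri>",
        model_data="s3://<bucket>/model.tar.gz",
        role="<execution-role-arn>",
    )

    # The id takes the form "<recommendation-job-name>/<8-character-suffix>".
    # instance_type and initial_instance_count are resolved from the selected
    # recommendation, so they can be omitted.
    predictor = model.deploy(
        inference_recommendation_id="my-recommender-job/a1b2c3d4",
    )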

File tree

10 files changed: +678 -30 lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
@@ -1342,6 +1342,7 @@ def deploy(
        volume_size=None,
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
        **kwargs,
    ):
        """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -1419,6 +1420,9 @@ def deploy(
                inference container to pass health check by SageMaker Hosting. For more information
                about health check see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
            **kwargs: Passed to invocation of ``create_model()``.
                Implementations may customize ``create_model()`` to accept
                ``**kwargs`` to customize model creation during deploy.
@@ -1483,6 +1487,7 @@ def deploy(
            volume_size=volume_size,
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
+            inference_recommendation_id=inference_recommendation_id,
        )

    def register(
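
On the estimator path, the new argument is simply forwarded to the underlying model's `deploy()` via the call shown above. A short sketch, assuming `my_estimator` is a hypothetical already-fitted estimator:

    # After fit(), the recommendation id can be passed straight through.
    predictor = my_estimator.deploy(
        inference_recommendation_id="my-recommender-job/a1b2c3d4",
    )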

src/sagemaker/huggingface/model.py

Lines changed: 5 additions & 0 deletions
@@ -209,6 +209,7 @@ def deploy(
        volume_size=None,
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
        **kwargs,
    ):
        """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -282,6 +283,9 @@ def deploy(
                inference container to pass health check by SageMaker Hosting. For more information
                about health check see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
        Raises:
            ValueError: If arguments combination check failed in these circumstances:
                - If no role is specified or
@@ -317,6 +321,7 @@ def deploy(
            volume_size=volume_size,
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
+            inference_recommendation_id=inference_recommendation_id,
        )

    def register(

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 198 additions & 11 deletions
@@ -14,11 +14,10 @@
from __future__ import absolute_import

import logging
+import re

from typing import List, Dict, Optional
-
import sagemaker
-
from sagemaker.parameter import CategoricalParameter

INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -101,13 +100,15 @@ def right_size(
                    'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4'])
                }]

-            phases (list[Phase]): Specifies the criteria for increasing load
-                during endpoint load tests. (default: None).
-            traffic_type (str): Specifies the traffic type that matches the phases. (default: None).
-            max_invocations (str): defines invocation limit for endpoint load tests (default: None).
-            model_latency_thresholds (list[ModelLatencyThreshold]): defines the response latency
-                thresholds for endpoint load tests (default: None).
-            max_tests (int): restricts how many endpoints are allowed to be
+            phases (list[Phase]): Shape of the traffic pattern to use in the load test
+                (default: None).
+            traffic_type (str): Specifies the traffic pattern type. Currently only supports
+                one type 'PHASES' (default: None).
+            max_invocations (str): defines the minimum invocations per minute for the endpoint
+                to support (default: None).
+            model_latency_thresholds (list[ModelLatencyThreshold]): defines the maximum response
+                latency for endpoints to support (default: None).
+            max_tests (int): restricts how many endpoints in total are allowed to be
                spun up for this job (default: None).
            max_parallel_tests (int): restricts how many concurrent endpoints
                this job is allowed to spin up (default: None).
@@ -122,7 +123,7 @@ def right_size(
            raise ValueError("right_size() is currently only supported with a registered model")

        if not framework and self._framework():
-            framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework, framework)
+            framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework)

        framework_version = self._get_framework_version()

@@ -176,7 +177,38 @@ def right_size(

        return self

-    def _check_inference_recommender_args(
+    def _update_params(
+        self,
+        **kwargs,
+    ):
+        """Check and update params based on inference recommendation id or right size case"""
+        instance_type = kwargs["instance_type"]
+        initial_instance_count = kwargs["initial_instance_count"]
+        accelerator_type = kwargs["accelerator_type"]
+        async_inference_config = kwargs["async_inference_config"]
+        serverless_inference_config = kwargs["serverless_inference_config"]
+        inference_recommendation_id = kwargs["inference_recommendation_id"]
+        inference_recommender_job_results = kwargs["inference_recommender_job_results"]
+        if inference_recommendation_id is not None:
+            inference_recommendation = self._update_params_for_recommendation_id(
+                instance_type=instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                async_inference_config=async_inference_config,
+                serverless_inference_config=serverless_inference_config,
+                inference_recommendation_id=inference_recommendation_id,
+            )
+        elif inference_recommender_job_results is not None:
+            inference_recommendation = self._update_params_for_right_size(
+                instance_type,
+                initial_instance_count,
+                accelerator_type,
+                serverless_inference_config,
+                async_inference_config,
+            )
+        return inference_recommendation or (instance_type, initial_instance_count)
+
+    def _update_params_for_right_size(
        self,
        instance_type=None,
        initial_instance_count=None,
@@ -232,6 +264,161 @@ def _check_inference_recommender_args(
        ]
        return (instance_type, initial_instance_count)

+    def _update_params_for_recommendation_id(
+        self,
+        instance_type,
+        initial_instance_count,
+        accelerator_type,
+        async_inference_config,
+        serverless_inference_config,
+        inference_recommendation_id,
+    ):
+        """Update parameters with inference recommendation results.
+
+        Args:
+            instance_type (str): The EC2 instance type to deploy this Model to.
+                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                serverless inference, then it is required to deploy a model.
+            initial_instance_count (int): The initial number of instances to run
+                in the ``Endpoint`` created from this ``Model``. If not using
+                serverless inference, then it need to be a number larger or equals
+                to 1.
+            accelerator_type (str): Type of Elastic Inference accelerator to
+                deploy this model for model loading and inference, for example,
+                'ml.eia1.medium'. If not specified, no Elastic Inference
+                accelerator will be attached to the endpoint. For more
+                information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
+            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
+                configuration related to async endpoint. Use this configuration when trying
+                to create async endpoint and make async inference. If empty config object
+                passed through, will use default config to deploy async endpoint. Deploy a
+                real-time endpoint if it's None.
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to serverless endpoint. Use this configuration
+                when trying to create serverless endpoint and make serverless inference. If
+                empty object passed through, will use pre-defined values in
+                ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
+                instance based endpoint if it's None.
+            inference_recommendation_id (str): The recommendation id which specifies
+                the recommendation you picked from inference recommendation job
+                results and would like to deploy the model and endpoint with
+                recommended parameters.
+        Raises:
+            ValueError: If arguments combination check failed in these circumstances:
+                - If only one of instance type or instance count specified or
+                - If recommendation id does not follow the required format or
+                - If recommendation id is not valid or
+                - If inference recommendation id is specified along with incompatible parameters
+        Returns:
+            (string, int): instance type and associated instance count from selected
+                inference recommendation id if arguments combination check passed.
+        """
+
+        if instance_type is not None and initial_instance_count is not None:
+            LOGGER.warning(
+                "Both instance_type and initial_instance_count are specified,"
+                "overriding the recommendation result."
+            )
+            return (instance_type, initial_instance_count)
+
+        # Validate non-compatible parameters with recommendation id
+        if bool(instance_type) != bool(initial_instance_count):
+            raise ValueError(
+                "Please either do not specify instance_type and initial_instance_count"
+                "since they are in recommendation, or specify both of them if you want"
+                "to override the recommendation."
+            )
+        if accelerator_type is not None:
+            raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
+        if async_inference_config is not None:
+            raise ValueError(
+                "async_inference_config is not compatible with inference_recommendation_id."
+            )
+        if serverless_inference_config is not None:
+            raise ValueError(
+                "serverless_inference_config is not compatible with inference_recommendation_id."
+            )
+
+        # Validate recommendation id
+        if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
+            raise ValueError("Inference Recommendation id is not valid")
+        recommendation_job_name = inference_recommendation_id.split("/")[0]
+
+        sage_client = self.sagemaker_session.sagemaker_client
+        recommendation_res = sage_client.describe_inference_recommendations_job(
+            JobName=recommendation_job_name
+        )
+        input_config = recommendation_res["InputConfig"]
+
+        recommendation = next(
+            (
+                rec
+                for rec in recommendation_res["InferenceRecommendations"]
+                if rec["RecommendationId"] == inference_recommendation_id
+            ),
+            None,
+        )
+
+        if not recommendation:
+            raise ValueError(
+                "inference_recommendation_id does not exist in InferenceRecommendations list"
+            )
+
+        model_config = recommendation["ModelConfiguration"]
+        envs = (
+            model_config["EnvironmentParameters"]
+            if "EnvironmentParameters" in model_config
+            else None
+        )
+        # Update envs
+        recommend_envs = {}
+        if envs is not None:
+            for env in envs:
+                recommend_envs[env["Key"]] = env["Value"]
+        self.env.update(recommend_envs)
+
+        # Update params with non-compilation recommendation results
+        if (
+            "InferenceSpecificationName" not in model_config
+            and "CompilationJobName" not in model_config
+        ):
+
+            if "ModelPackageVersionArn" in input_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["InferenceSpecification"]["Containers"][0][
+                    "ModelDataUrl"
+                ]
+                self.image_uri = modelpkg_res["InferenceSpecification"]["Containers"][0]["Image"]
+            elif "ModelName" in input_config:
+                model_res = sage_client.describe_model(ModelName=input_config["ModelName"])
+                self.model_data = model_res["PrimaryContainer"]["ModelDataUrl"]
+                self.image_uri = model_res["PrimaryContainer"]["Image"]
+        else:
+            if "InferenceSpecificationName" in model_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["ModelDataUrl"]
+                self.image_uri = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["Image"]
+            elif "CompilationJobName" in model_config:
+                compilation_res = sage_client.describe_compilation_job(
+                    CompilationJobName=model_config["CompilationJobName"]
+                )
+                self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
+                self.image_uri = compilation_res["InferenceImage"]
+
+        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
+        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
+
+        return (instance_type, initial_instance_count)
+
    def _convert_to_endpoint_configurations_json(
        self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
    ):
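
For reference, the `re.match` check added above accepts ids of the form `<job-name>/<suffix>`: a job name of 1-64 alphanumeric characters with interior hyphens, a literal slash, and an 8-character suffix of word characters. A self-contained sketch of how that pattern behaves; the example ids are made up:

    import re

    # Pattern copied verbatim from _update_params_for_recommendation_id.
    ID_PATTERN = r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$"

    # A job name plus an 8-character recommendation suffix passes validation.
    assert re.match(ID_PATTERN, "my-recommender-job/a1b2c3d4")

    # A missing slash or a wrong-length suffix is rejected.
    assert not re.match(ID_PATTERN, "my-recommender-job")
    assert not re.match(ID_PATTERN, "my-recommender-job/abc")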

src/sagemaker/model.py

Lines changed: 20 additions & 12 deletions
@@ -1035,6 +1035,7 @@ def deploy(
        volume_size=None,
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
        **kwargs,
    ):
        """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1110,31 +1111,24 @@ def deploy(
                inference container to pass health check by SageMaker Hosting. For more information
                about health check see:
                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
        Raises:
            ValueError: If arguments combination check failed in these circumstances:
                - If no role is specified or
                - If serverless inference config is not specified and instance type and instance
                    count are also not specified or
                - If a wrong type of object is provided as serverless inference config or async
-                    inference config
+                    inference config or
+                - If inference recommendation id is specified along with incompatible parameters
        Returns:
            callable[string, sagemaker.session.Session] or None: Invocation of
            ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
            is not None. Otherwise, return None.
        """
        removed_kwargs("update_endpoint", kwargs)

-        if self.inference_recommender_job_results:
-            inference_recommendation = self._check_inference_recommender_args(
-                instance_type,
-                initial_instance_count,
-                accelerator_type,
-                serverless_inference_config,
-                async_inference_config,
-            )
-            if inference_recommendation:
-                instance_type, initial_instance_count = inference_recommendation
-
        self._init_sagemaker_session_if_does_not_exist(instance_type)

        tags = add_jumpstart_tags(
@@ -1144,6 +1138,20 @@ def deploy(
        if self.role is None:
            raise ValueError("Role can not be null for deploying a model")

+        if (
+            inference_recommendation_id is not None
+            or self.inference_recommender_job_results is not None
+        ):
+            instance_type, initial_instance_count = self._update_params(
+                instance_type=instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                async_inference_config=async_inference_config,
+                serverless_inference_config=serverless_inference_config,
+                inference_recommendation_id=inference_recommendation_id,
+                inference_recommender_job_results=self.inference_recommender_job_results,
+            )
+
        is_async = async_inference_config is not None
        if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
            raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 0 deletions
@@ -323,6 +323,7 @@ def deploy(
        volume_size=None,
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
    ):
        """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""

@@ -347,6 +348,7 @@ def deploy(
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
            update_endpoint=update_endpoint,
+            inference_recommendation_id=inference_recommendation_id,
        )

    def _eia_supported(self):

0 commit comments
