
Commit 8372c5e

jinpengqi (Jinpeng Qi) committed
Inference recommendation id deployment support (aws#800)
* Add IR id as input and validate compatible parameters
* use inference_recommender_mixin
* Add inf rec id support for deploy method
* Invoke describe boto3 apis directly and wrap params update in mixin
* Refactor mixin
* Get compiled results from modelpkg
* Fix describe calls
* Add check for initial_instance_count
* Fix flake8 and docstyle errors
* Black reformat
* Add recommendation id deployment support for estimator
* Remove ir id for auto_ml model
* Rebase on right size change

Co-authored-by: Jinpeng Qi <[email protected]>
1 parent 6322794 commit 8372c5e

File tree

9 files changed: +629, -13 lines changed

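In practice, this commit lets a caller hand `deploy()` a recommendation id from a completed Inference Recommender job instead of picking instance settings by hand. A minimal usage sketch, assuming hypothetical artifacts, role, and job name (placeholders, not values from this commit):

```python
from sagemaker.model import Model

# Placeholders; substitute your own image, artifact, and execution role.
model = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://<bucket>/model.tar.gz",
    role="<execution-role-arn>",
)

# The id takes the form "<recommendation-job-name>/<id>". Instance type and
# initial instance count are filled in from the chosen recommendation.
predictor = model.deploy(
    inference_recommendation_id="my-inference-recommender-job/AbCd1234",
)
```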

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1342,6 +1342,7 @@ def deploy(
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
         **kwargs,
     ):
         """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -1419,6 +1420,9 @@ def deploy(
                 inference container to pass health check by SageMaker Hosting. For more information
                 about health check see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
             **kwargs: Passed to invocation of ``create_model()``.
                 Implementations may customize ``create_model()`` to accept
                 ``**kwargs`` to customize model creation during deploy.
@@ -1483,6 +1487,7 @@ def deploy(
             volume_size=volume_size,
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
+            inference_recommendation_id=inference_recommendation_id,
         )
 
     def register(
```
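Per the commit message, the estimator path gets the same support: `deploy()` on the estimator simply forwards the id to the underlying model's `deploy()`. A hedged sketch, assuming a fitted estimator named `estimator` (hypothetical):

```python
# After estimator.fit(...), deploy straight from a picked recommendation
# rather than passing instance_type/initial_instance_count explicitly.
predictor = estimator.deploy(
    inference_recommendation_id="my-inference-recommender-job/AbCd1234",
)
```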

src/sagemaker/huggingface/model.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -209,6 +209,7 @@ def deploy(
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -282,6 +283,9 @@ def deploy(
                 inference container to pass health check by SageMaker Hosting. For more information
                 about health check see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -317,6 +321,7 @@ def deploy(
             volume_size=volume_size,
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
+            inference_recommendation_id=inference_recommendation_id,
         )
 
     def register(
```

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 188 additions & 3 deletions
```diff
@@ -14,11 +14,10 @@
 from __future__ import absolute_import
 
 import logging
+import re
 
 from typing import List, Dict, Optional
-
 import sagemaker
-
 from sagemaker.parameter import CategoricalParameter
 
 INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -176,7 +175,37 @@ def right_size(
 
         return self
 
-    def _check_inference_recommender_args(
+    def _update_params(
+        self,
+        instance_type,
+        initial_instance_count,
+        accelerator_type,
+        async_inference_config,
+        serverless_inference_config,
+        inference_recommendation_id,
+        inference_recommender_job_results,
+    ):
+        """Check and update params based on inference recommendation id or right size case"""
+        if inference_recommendation_id is not None:
+            inference_recommendation = self._update_params_for_recommendation_id(
+                instance_type=instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                async_inference_config=async_inference_config,
+                serverless_inference_config=serverless_inference_config,
+                inference_recommendation_id=inference_recommendation_id,
+            )
+        elif inference_recommender_job_results is not None:
+            inference_recommendation = self._update_params_for_right_size(
+                instance_type,
+                initial_instance_count,
+                accelerator_type,
+                serverless_inference_config,
+                async_inference_config,
+            )
+        return inference_recommendation or (instance_type, initial_instance_count)
+
+    def _update_params_for_right_size(
         self,
         instance_type=None,
         initial_instance_count=None,
@@ -232,6 +261,162 @@ def _check_inference_recommender_args(
         ]
         return (instance_type, initial_instance_count)
 
+    def _update_params_for_recommendation_id(
+        self,
+        instance_type,
+        initial_instance_count,
+        accelerator_type,
+        async_inference_config,
+        serverless_inference_config,
+        inference_recommendation_id,
+    ):
+        """Update parameters with inference recommendation results.
+
+        Args:
+            instance_type (str): The EC2 instance type to deploy this Model to.
+                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                serverless inference, then it is required to deploy a model.
+            initial_instance_count (int): The initial number of instances to run
+                in the ``Endpoint`` created from this ``Model``. If not using
+                serverless inference, then it need to be a number larger or equals
+                to 1.
+            accelerator_type (str): Type of Elastic Inference accelerator to
+                deploy this model for model loading and inference, for example,
+                'ml.eia1.medium'. If not specified, no Elastic Inference
+                accelerator will be attached to the endpoint. For more
+                information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
+            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
+                configuration related to async endpoint. Use this configuration when trying
+                to create async endpoint and make async inference. If empty config object
+                passed through, will use default config to deploy async endpoint. Deploy a
+                real-time endpoint if it's None.
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to serverless endpoint. Use this configuration
+                when trying to create serverless endpoint and make serverless inference. If
+                empty object passed through, will use pre-defined values in
+                ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
+                instance based endpoint if it's None.
+            inference_recommendation_id (str): The recommendation id which specifies
+                the recommendation you picked from inference recommendation job
+                results and would like to deploy the model and endpoint with
+                recommended parameters.
+        Raises:
+            ValueError: If arguments combination check failed in these circumstances:
+                - If only one of instance type or instance count specified or
+                - If recommendation id does not follow the required format or
+                - If recommendation id is not valid or
+                - If inference recommendation id is specified along with incompatible parameters
+        Returns:
+            (string, int): instance type and associated instance count from selected
+            inference recommendation id if arguments combination check passed.
+        """
+
+        if instance_type is not None and initial_instance_count is not None:
+            LOGGER.warning(
+                "Both instance_type and initial_instance_count are specified,"
+                "overriding the recommendation result."
+            )
+            return (instance_type, initial_instance_count)
+
+        # Validate non-compatible parameters with recommendation id
+        if bool(instance_type) != bool(initial_instance_count):
+            raise ValueError(
+                "Please either do not specify instance_type and initial_instance_count"
+                "since they are in recommendation, or specify both of them if you want"
+                "to override the recommendation."
+            )
+        if accelerator_type is not None:
+            raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
+        if async_inference_config is not None:
+            raise ValueError(
+                "async_inference_config is not compatible with inference_recommendation_id."
+            )
+        if serverless_inference_config is not None:
+            raise ValueError(
+                "serverless_inference_config is not compatible with inference_recommendation_id."
+            )
+
+        # Validate recommendation id
+        if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
+            raise ValueError("Inference Recommendation id is not valid")
+        recommendation_job_name = inference_recommendation_id.split("/")[0]
+
+        sage_client = self.sagemaker_session.sagemaker_client
+        recommendation_res = sage_client.describe_inference_recommendations_job(
+            JobName=recommendation_job_name
+        )
+        input_config = recommendation_res["InputConfig"]
+
+        recommendation = next(
+            (
+                rec
+                for rec in recommendation_res["InferenceRecommendations"]
+                if rec["RecommendationId"] == inference_recommendation_id
+            ),
+            None,
+        )
+
+        if not recommendation:
+            raise ValueError(
+                "inference_recommendation_id does not exist in InferenceRecommendations list"
+            )
+
+        model_config = recommendation["ModelConfiguration"]
+        envs = (
+            model_config["EnvironmentParameters"]
+            if "EnvironmentParameters" in model_config
+            else None
+        )
+        # Update envs
+        recommend_envs = {}
+        if envs is not None:
+            for env in envs:
+                recommend_envs[env["Key"]] = env["Value"]
+        self.env.update(recommend_envs)
+
+        # Update params with non-compilation recommendation results
+        if (
+            "InferenceSpecificationName" not in model_config
+            and "CompilationJobName" not in model_config
+        ):
+
+            if "ModelPackageVersionArn" in input_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["InferenceSpecification"]["Containers"][0][
+                    "ModelDataUrl"
+                ]
+                self.image_uri = modelpkg_res["InferenceSpecification"]["Containers"][0]["Image"]
+            elif "ModelName" in input_config:
+                model_res = sage_client.describe_model(ModelName=input_config["ModelName"])
+                self.model_data = model_res["PrimaryContainer"]["ModelDataUrl"]
+                self.image_uri = model_res["PrimaryContainer"]["Image"]
+        else:
+            # Update params with compilation recommendation results
+            if "InferenceSpecificationName" in model_config:
+                modelpkg_res = sage_client.describe_model_package(
+                    ModelPackageName=input_config["ModelPackageVersionArn"]
+                )
+                self.model_data = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["ModelDataUrl"]
+                self.image_uri = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
+                    "Containers"
+                ][0]["Image"]
+            elif "CompilationJobName" in model_config:
+                compilation_res = sage_client.describe_compilation_job(
+                    CompilationJobName=model_config["CompilationJobName"]
+                )
+                self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
+                self.image_uri = compilation_res["InferenceImage"]
+
+        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
+        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
+
+        return (instance_type, initial_instance_count)
+
     def _convert_to_endpoint_configurations_json(
         self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
     ):
```
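The format check above is a single regex: a job name (alphanumeric segments joined by hyphens) followed by a slash and an 8-character id. A standalone sketch of the same validation; the helper name is hypothetical:

```python
import re

def looks_like_recommendation_id(rec_id: str) -> bool:
    """True if rec_id matches '<job-name>/<8-char-id>', per the check above."""
    return re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", rec_id) is not None

print(looks_like_recommendation_id("my-inference-recommender-job/AbCd1234"))  # True
print(looks_like_recommendation_id("bad name/AbCd1234"))  # False (space in job name)
print(looks_like_recommendation_id("my-job/short"))       # False (id is not 8 word chars)
```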

src/sagemaker/model.py

Lines changed: 18 additions & 10 deletions
```diff
@@ -1035,6 +1035,7 @@ def deploy(
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1110,30 +1111,37 @@ def deploy(
                 inference container to pass health check by SageMaker Hosting. For more information
                 about health check see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            inference_recommendation_id (str): The recommendation id which specifies the
+                recommendation you picked from inference recommendation job results and
+                would like to deploy the model and endpoint with recommended parameters.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
                 - If serverless inference config is not specified and instance type and instance
                     count are also not specified or
                 - If a wrong type of object is provided as serverless inference config or async
-                    inference config
+                    inference config or
+                - If inference recommendation id is specified along with incompatible parameters
         Returns:
             callable[string, sagemaker.session.Session] or None: Invocation of
             ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
             is not None. Otherwise, return None.
         """
         removed_kwargs("update_endpoint", kwargs)
 
-        if self.inference_recommender_job_results:
-            inference_recommendation = self._check_inference_recommender_args(
-                instance_type,
-                initial_instance_count,
-                accelerator_type,
-                serverless_inference_config,
-                async_inference_config,
+        if (
+            inference_recommendation_id is not None
+            or self.inference_recommender_job_results is not None
+        ):
+            instance_type, initial_instance_count = self._update_params(
+                instance_type=instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                async_inference_config=async_inference_config,
+                serverless_inference_config=serverless_inference_config,
+                inference_recommendation_id=inference_recommendation_id,
+                inference_recommender_job_results=self.inference_recommender_job_results,
             )
-            if inference_recommendation:
-                instance_type, initial_instance_count = inference_recommendation
 
         self._init_sagemaker_session_if_does_not_exist(instance_type)
```
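Together, the new branch above and the mixin give these call-time semantics: the recommendation supplies `instance_type` and `initial_instance_count` unless the caller specifies both (explicit values win, with a logged warning); specifying exactly one of the two, or combining the id with `accelerator_type`, `async_inference_config`, or `serverless_inference_config`, raises `ValueError`. A sketch of the resulting patterns (job name hypothetical):

```python
# Recommendation supplies instance type and count.
model.deploy(inference_recommendation_id="my-inference-recommender-job/AbCd1234")

# Both explicit values given: they override the recommendation (warning logged).
model.deploy(
    instance_type="ml.c5.xlarge",
    initial_instance_count=2,
    inference_recommendation_id="my-inference-recommender-job/AbCd1234",
)

# Exactly one of the two, or an incompatible config, raises ValueError:
# model.deploy(instance_type="ml.c5.xlarge",
#              inference_recommendation_id="my-inference-recommender-job/AbCd1234")
```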

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -323,6 +323,7 @@ def deploy(
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        inference_recommendation_id=None,
     ):
         """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
 
@@ -347,6 +348,7 @@ def deploy(
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             update_endpoint=update_endpoint,
+            inference_recommendation_id=inference_recommendation_id,
         )
 
     def _eia_supported(self):
```
