14 | 14 | from __future__ import absolute_import
15 | 15 |
16 | 16 | import logging
   | 17 | +import re
17 | 18 |
18 | 19 | from typing import List, Dict, Optional
19 |    | -
20 | 20 | import sagemaker
21 |    | -
22 | 21 | from sagemaker.parameter import CategoricalParameter
23 | 22 |
24 | 23 | INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -176,7 +175,37 @@ def right_size(
176 | 175 |
177 | 176 |         return self
178 | 177 |
179 |     | -    def _check_inference_recommender_args(
    | 178 | +    def _update_params(
    | 179 | +        self,
    | 180 | +        instance_type,
    | 181 | +        initial_instance_count,
    | 182 | +        accelerator_type,
    | 183 | +        async_inference_config,
    | 184 | +        serverless_inference_config,
    | 185 | +        inference_recommendation_id,
    | 186 | +        inference_recommender_job_results,
    | 187 | +    ):
    | 188 | +        """Check and update params based on inference recommendation id or right size case"""
    | 189 | +        if inference_recommendation_id is not None:
    | 190 | +            inference_recommendation = self._update_params_for_recommendation_id(
    | 191 | +                instance_type=instance_type,
    | 192 | +                initial_instance_count=initial_instance_count,
    | 193 | +                accelerator_type=accelerator_type,
    | 194 | +                async_inference_config=async_inference_config,
    | 195 | +                serverless_inference_config=serverless_inference_config,
    | 196 | +                inference_recommendation_id=inference_recommendation_id,
    | 197 | +            )
    | 198 | +        elif inference_recommender_job_results is not None:
    | 199 | +            inference_recommendation = self._update_params_for_right_size(
    | 200 | +                instance_type,
    | 201 | +                initial_instance_count,
    | 202 | +                accelerator_type,
    | 203 | +                serverless_inference_config,
    | 204 | +                async_inference_config,
    | 205 | +            )
    | 206 | +        return inference_recommendation or (instance_type, initial_instance_count)
    | 207 | +
    | 208 | +    def _update_params_for_right_size(
180 | 209 |         self,
181 | 210 |         instance_type=None,
182 | 211 |         initial_instance_count=None,
@@ -232,6 +261,162 @@ def _check_inference_recommender_args(
232 | 261 |         ]
233 | 262 |         return (instance_type, initial_instance_count)
234 | 263 |
    | 264 | +    def _update_params_for_recommendation_id(
    | 265 | +        self,
    | 266 | +        instance_type,
    | 267 | +        initial_instance_count,
    | 268 | +        accelerator_type,
    | 269 | +        async_inference_config,
    | 270 | +        serverless_inference_config,
    | 271 | +        inference_recommendation_id,
    | 272 | +    ):
    | 273 | +        """Update parameters with inference recommendation results.
    | 274 | +
    | 275 | +        Args:
    | 276 | +            instance_type (str): The EC2 instance type to deploy this Model to.
    | 277 | +                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
    | 278 | +                serverless inference, then it is required to deploy a model.
    | 279 | +            initial_instance_count (int): The initial number of instances to run
    | 280 | +                in the ``Endpoint`` created from this ``Model``. If not using
    | 281 | +                serverless inference, then it needs to be a number greater than or
    | 282 | +                equal to 1.
    | 283 | +            accelerator_type (str): Type of Elastic Inference accelerator to
    | 284 | +                deploy this model for model loading and inference, for example,
    | 285 | +                'ml.eia1.medium'. If not specified, no Elastic Inference
    | 286 | +                accelerator will be attached to the endpoint. For more
    | 287 | +                information:
    | 288 | +                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
    | 289 | +            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
    | 290 | +                configuration related to an async endpoint. Use this configuration when
    | 291 | +                creating an async endpoint and making async inferences. If an empty config
    | 292 | +                object is passed through, a default config is used to deploy an async
    | 293 | +                endpoint. A real-time endpoint is deployed if it's None.
    | 294 | +            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
    | 295 | +                Specifies configuration related to a serverless endpoint. Use this
    | 296 | +                configuration when creating a serverless endpoint and making serverless
    | 297 | +                inferences. If an empty object is passed through, the pre-defined values in
    | 298 | +                the ``ServerlessInferenceConfig`` class are used to deploy a serverless
    | 299 | +                endpoint. An instance-based endpoint is deployed if it's None.
    | 300 | +            inference_recommendation_id (str): The recommendation id that identifies
    | 301 | +                the recommendation you picked from the inference recommendation job
    | 302 | +                results and whose recommended parameters should be used to deploy
    | 303 | +                the model and endpoint.
    | 304 | +        Raises:
    | 305 | +            ValueError: If the argument combination check fails, i.e. when
    | 306 | +                - only one of instance type and instance count is specified, or
    | 307 | +                - the recommendation id does not follow the required format, or
    | 308 | +                - the recommendation id is not valid, or
    | 309 | +                - the recommendation id is specified along with incompatible parameters.
    | 310 | +        Returns:
    | 311 | +            (str, int): The instance type and associated instance count from the selected
    | 312 | +                inference recommendation id, if the argument combination check passes.
    | 313 | +        """
    | 314 | +
    | 315 | +        if instance_type is not None and initial_instance_count is not None:
    | 316 | +            LOGGER.warning(
    | 317 | +                "Both instance_type and initial_instance_count are specified, "
    | 318 | +                "overriding the recommendation result."
    | 319 | +            )
    | 320 | +            return (instance_type, initial_instance_count)
    | 321 | +
    | 322 | +        # Validate parameters that are not compatible with a recommendation id
    | 323 | +        if bool(instance_type) != bool(initial_instance_count):
    | 324 | +            raise ValueError(
    | 325 | +                "Please either do not specify instance_type and initial_instance_count "
    | 326 | +                "since they are in the recommendation, or specify both of them if you want "
    | 327 | +                "to override the recommendation."
    | 328 | +            )
    | 329 | +        if accelerator_type is not None:
    | 330 | +            raise ValueError("accelerator_type is not compatible with inference_recommendation_id.")
    | 331 | +        if async_inference_config is not None:
    | 332 | +            raise ValueError(
    | 333 | +                "async_inference_config is not compatible with inference_recommendation_id."
    | 334 | +            )
    | 335 | +        if serverless_inference_config is not None:
    | 336 | +            raise ValueError(
    | 337 | +                "serverless_inference_config is not compatible with inference_recommendation_id."
    | 338 | +            )
    | 339 | +
    | 340 | +        # Validate recommendation id
    | 341 | +        if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
    | 342 | +            raise ValueError("Inference Recommendation id is not valid")
    | 343 | +        recommendation_job_name = inference_recommendation_id.split("/")[0]
    | 344 | +
    | 345 | +        sage_client = self.sagemaker_session.sagemaker_client
    | 346 | +        recommendation_res = sage_client.describe_inference_recommendations_job(
    | 347 | +            JobName=recommendation_job_name
    | 348 | +        )
    | 349 | +        input_config = recommendation_res["InputConfig"]
    | 350 | +
    | 351 | +        recommendation = next(
    | 352 | +            (
    | 353 | +                rec
    | 354 | +                for rec in recommendation_res["InferenceRecommendations"]
    | 355 | +                if rec["RecommendationId"] == inference_recommendation_id
    | 356 | +            ),
    | 357 | +            None,
    | 358 | +        )
    | 359 | +
    | 360 | +        if not recommendation:
    | 361 | +            raise ValueError(
    | 362 | +                "inference_recommendation_id does not exist in InferenceRecommendations list"
    | 363 | +            )
    | 364 | +
    | 365 | +        model_config = recommendation["ModelConfiguration"]
    | 366 | +        envs = (
    | 367 | +            model_config["EnvironmentParameters"]
    | 368 | +            if "EnvironmentParameters" in model_config
    | 369 | +            else None
    | 370 | +        )
    | 371 | +        # Update envs
    | 372 | +        recommend_envs = {}
    | 373 | +        if envs is not None:
    | 374 | +            for env in envs:
    | 375 | +                recommend_envs[env["Key"]] = env["Value"]
    | 376 | +        self.env.update(recommend_envs)
    | 377 | +
    | 378 | +        # Update params with non-compilation recommendation results
    | 379 | +        if (
    | 380 | +            "InferenceSpecificationName" not in model_config
    | 381 | +            and "CompilationJobName" not in model_config
    | 382 | +        ):
    | 383 | +
    | 384 | +            if "ModelPackageVersionArn" in input_config:
    | 385 | +                modelpkg_res = sage_client.describe_model_package(
    | 386 | +                    ModelPackageName=input_config["ModelPackageVersionArn"]
    | 387 | +                )
    | 388 | +                self.model_data = modelpkg_res["InferenceSpecification"]["Containers"][0][
    | 389 | +                    "ModelDataUrl"
    | 390 | +                ]
    | 391 | +                self.image_uri = modelpkg_res["InferenceSpecification"]["Containers"][0]["Image"]
    | 392 | +            elif "ModelName" in input_config:
    | 393 | +                model_res = sage_client.describe_model(ModelName=input_config["ModelName"])
    | 394 | +                self.model_data = model_res["PrimaryContainer"]["ModelDataUrl"]
    | 395 | +                self.image_uri = model_res["PrimaryContainer"]["Image"]
    | 396 | +        else:
    | 397 | +            # Update params with compilation recommendation results
    | 398 | +            if "InferenceSpecificationName" in model_config:
    | 399 | +                modelpkg_res = sage_client.describe_model_package(
    | 400 | +                    ModelPackageName=input_config["ModelPackageVersionArn"]
    | 401 | +                )
    | 402 | +                self.model_data = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
    | 403 | +                    "Containers"
    | 404 | +                ][0]["ModelDataUrl"]
    | 405 | +                self.image_uri = modelpkg_res["AdditionalInferenceSpecificationDefinition"][
    | 406 | +                    "Containers"
    | 407 | +                ][0]["Image"]
    | 408 | +            elif "CompilationJobName" in model_config:
    | 409 | +                compilation_res = sage_client.describe_compilation_job(
    | 410 | +                    CompilationJobName=model_config["CompilationJobName"]
    | 411 | +                )
    | 412 | +                self.model_data = compilation_res["ModelArtifacts"]["S3ModelArtifacts"]
    | 413 | +                self.image_uri = compilation_res["InferenceImage"]
    | 414 | +
    | 415 | +        instance_type = recommendation["EndpointConfiguration"]["InstanceType"]
    | 416 | +        initial_instance_count = recommendation["EndpointConfiguration"]["InitialInstanceCount"]
    | 417 | +
    | 418 | +        return (instance_type, initial_instance_count)
    | 419 | +
235 | 420 |     def _convert_to_endpoint_configurations_json(
236 | 421 |         self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
237 | 422 |     ):
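
To see the flow this diff enables end to end, here is a minimal usage sketch. The image URI, model artifact, role, and recommender job name are hypothetical placeholders, and the snippet assumes this mixin is wired into `Model.deploy` as the PR intends:

```python
from sagemaker.model import Model

# Hypothetical placeholders: none of these values come from this diff.
model = Model(
    image_uri="<ecr-image-uri>",
    model_data="s3://my-bucket/model.tar.gz",
    role="<execution-role-arn>",
)

# Deploy with a recommendation picked from a completed Inference
# Recommender job. instance_type and initial_instance_count come from
# the recommendation, so neither needs to be passed; passing both
# overrides the recommendation, and passing only one raises ValueError
# (see _update_params_for_recommendation_id above).
predictor = model.deploy(
    inference_recommendation_id="my-recommender-job/AbCd1234",
)
```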
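The id-format check on new line 341 expects `<job-name>/<8-character-suffix>`, where the job name is alphanumeric with optional interior hyphens and at most 64 characters. A quick standalone check of that regex (the example ids are made up):

```python
import re

ID_PATTERN = r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$"

# Job name + "/" + exactly eight word characters matches.
assert re.match(ID_PATTERN, "my-recommender-job/AbCd1234")
# A short suffix or a leading hyphen does not.
assert not re.match(ID_PATTERN, "my-recommender-job/AbC")
assert not re.match(ID_PATTERN, "-bad-job-name/AbCd1234")
```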
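The mutual-requirement check on new line 323 is a boolean XOR idiom: `bool(instance_type) != bool(initial_instance_count)` is true exactly when one of the two is set. A tiny illustration, with one caveat worth noting: a count of 0 is falsy, so it behaves like "not set" here.

```python
def exactly_one_set(instance_type, initial_instance_count):
    # True when exactly one of the two values is truthy.
    return bool(instance_type) != bool(initial_instance_count)

assert exactly_one_set("ml.c5.xlarge", None)    # only the type given
assert not exactly_one_set(None, None)          # neither given
assert not exactly_one_set("ml.c5.xlarge", 1)   # both given
assert not exactly_one_set(None, 0)             # 0 is falsy, treated as unset
```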
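New lines 366-376 fold the recommendation's `EnvironmentParameters`, a list of `Key`/`Value` pairs in the describe response, into the model's environment dict. The same transformation in isolation, with made-up parameter names and values:

```python
# Illustrative shape of ModelConfiguration.EnvironmentParameters;
# the keys and values below are hypothetical.
envs = [
    {"Key": "OMP_NUM_THREADS", "Value": "4"},
    {"Key": "TS_DEFAULT_WORKERS_PER_MODEL", "Value": "2"},
]

model_env = {"EXISTING_VAR": "kept"}
model_env.update({env["Key"]: env["Value"] for env in envs})

assert model_env["OMP_NUM_THREADS"] == "4"
assert model_env["EXISTING_VAR"] == "kept"  # existing entries survive
```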