@@ -237,7 +237,12 @@ def _update_params(
237
237
async_inference_config ,
238
238
explainer_config ,
239
239
)
240
- return inference_recommendation or (instance_type , initial_instance_count )
240
+
241
+ return (
242
+ inference_recommendation
243
+ if inference_recommendation
244
+ else (instance_type , initial_instance_count )
245
+ )
241
246
242
247
def _update_params_for_right_size (
243
248
self ,
@@ -365,12 +370,6 @@ def _update_params_for_recommendation_id(
365
370
return (instance_type , initial_instance_count )
366
371
367
372
# Validate non-compatible parameters with recommendation id
368
- if bool (instance_type ) != bool (initial_instance_count ):
369
- raise ValueError (
370
- "Please either do not specify instance_type and initial_instance_count"
371
- "since they are in recommendation, or specify both of them if you want"
372
- "to override the recommendation."
373
- )
374
373
if accelerator_type is not None :
375
374
raise ValueError ("accelerator_type is not compatible with inference_recommendation_id." )
376
375
if async_inference_config is not None :
@@ -386,30 +385,38 @@ def _update_params_for_recommendation_id(
386
385
387
386
# Validate recommendation id
388
387
if not re .match (r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$" , inference_recommendation_id ):
389
- raise ValueError ("Inference Recommendation id is not valid" )
390
- recommendation_job_name = inference_recommendation_id .split ("/" )[0 ]
388
+ raise ValueError ("inference_recommendation_id is not valid" )
389
+ job_or_model_name = inference_recommendation_id .split ("/" )[0 ]
391
390
392
391
sage_client = self .sagemaker_session .sagemaker_client
393
- recommendation_res = sage_client .describe_inference_recommendations_job (
394
- JobName = recommendation_job_name
392
+ # Get recommendation from right size job and model
393
+ (
394
+ right_size_recommendation ,
395
+ model_recommendation ,
396
+ right_size_job_res ,
397
+ ) = self ._get_recommendation (
398
+ sage_client = sage_client ,
399
+ job_or_model_name = job_or_model_name ,
400
+ inference_recommendation_id = inference_recommendation_id ,
395
401
)
396
- input_config = recommendation_res ["InputConfig" ]
397
402
398
- recommendation = next (
399
- (
400
- rec
401
- for rec in recommendation_res ["InferenceRecommendations" ]
402
- if rec ["RecommendationId" ] == inference_recommendation_id
403
- ),
404
- None ,
405
- )
403
+ # Update params beased on model recommendation
404
+ if model_recommendation :
405
+ if initial_instance_count is None :
406
+ raise ValueError ("Must specify model recommendation id and instance count." )
407
+ self .env .update (model_recommendation ["Environment" ])
408
+ instance_type = model_recommendation ["InstanceType" ]
409
+ return (instance_type , initial_instance_count )
406
410
407
- if not recommendation :
411
+ # Update params based on default inference recommendation
412
+ if bool (instance_type ) != bool (initial_instance_count ):
408
413
raise ValueError (
409
- "inference_recommendation_id does not exist in InferenceRecommendations list"
414
+ "instance_type and initial_instance_count are mutually exclusive with"
415
+ "recommendation id since they are in recommendation."
416
+ "Please specify both of them if you want to override the recommendation."
410
417
)
411
-
412
- model_config = recommendation ["ModelConfiguration" ]
418
+ input_config = right_size_job_res [ "InputConfig" ]
419
+ model_config = right_size_recommendation ["ModelConfiguration" ]
413
420
envs = (
414
421
model_config ["EnvironmentParameters" ]
415
422
if "EnvironmentParameters" in model_config
@@ -458,8 +465,10 @@ def _update_params_for_recommendation_id(
458
465
self .model_data = compilation_res ["ModelArtifacts" ]["S3ModelArtifacts" ]
459
466
self .image_uri = compilation_res ["InferenceImage" ]
460
467
461
- instance_type = recommendation ["EndpointConfiguration" ]["InstanceType" ]
462
- initial_instance_count = recommendation ["EndpointConfiguration" ]["InitialInstanceCount" ]
468
+ instance_type = right_size_recommendation ["EndpointConfiguration" ]["InstanceType" ]
469
+ initial_instance_count = right_size_recommendation ["EndpointConfiguration" ][
470
+ "InitialInstanceCount"
471
+ ]
463
472
464
473
return (instance_type , initial_instance_count )
465
474
@@ -527,3 +536,77 @@ def _convert_to_stopping_conditions_json(
527
536
threshold .to_json for threshold in model_latency_thresholds
528
537
]
529
538
return stopping_conditions
539
+
540
+ def _get_recommendation (self , sage_client , job_or_model_name , inference_recommendation_id ):
541
+ """Get recommendation from right size job and model"""
542
+ right_size_recommendation , model_recommendation , right_size_job_res = None , None , None
543
+ right_size_recommendation , right_size_job_res = self ._get_right_size_recommendation (
544
+ sage_client = sage_client ,
545
+ job_or_model_name = job_or_model_name ,
546
+ inference_recommendation_id = inference_recommendation_id ,
547
+ )
548
+ if right_size_recommendation is None :
549
+ model_recommendation = self ._get_model_recommendation (
550
+ sage_client = sage_client ,
551
+ job_or_model_name = job_or_model_name ,
552
+ inference_recommendation_id = inference_recommendation_id ,
553
+ )
554
+ if model_recommendation is None :
555
+ raise ValueError ("inference_recommendation_id is not valid" )
556
+
557
+ return right_size_recommendation , model_recommendation , right_size_job_res
558
+
559
+ def _get_right_size_recommendation (
560
+ self ,
561
+ sage_client ,
562
+ job_or_model_name ,
563
+ inference_recommendation_id ,
564
+ ):
565
+ """Get recommendation from right size job"""
566
+ right_size_recommendation , right_size_job_res = None , None
567
+ try :
568
+ right_size_job_res = sage_client .describe_inference_recommendations_job (
569
+ JobName = job_or_model_name
570
+ )
571
+ if right_size_job_res :
572
+ right_size_recommendation = self ._search_recommendation (
573
+ recommendation_list = right_size_job_res ["InferenceRecommendations" ],
574
+ inference_recommendation_id = inference_recommendation_id ,
575
+ )
576
+ except sage_client .exceptions .ResourceNotFound :
577
+ pass
578
+
579
+ return right_size_recommendation , right_size_job_res
580
+
581
+ def _get_model_recommendation (
582
+ self ,
583
+ sage_client ,
584
+ job_or_model_name ,
585
+ inference_recommendation_id ,
586
+ ):
587
+ """Get recommendation from model"""
588
+ model_recommendation = None
589
+ try :
590
+ model_res = sage_client .describe_model (ModelName = job_or_model_name )
591
+ if model_res :
592
+ model_recommendation = self ._search_recommendation (
593
+ recommendation_list = model_res ["DeploymentRecommendation" ][
594
+ "RealTimeInferenceRecommendations"
595
+ ],
596
+ inference_recommendation_id = inference_recommendation_id ,
597
+ )
598
+ except sage_client .exceptions .ResourceNotFound :
599
+ pass
600
+
601
+ return model_recommendation
602
+
603
+ def _search_recommendation (self , recommendation_list , inference_recommendation_id ):
604
+ """Search recommendation based on recommendation id"""
605
+ return next (
606
+ (
607
+ rec
608
+ for rec in recommendation_list
609
+ if rec ["RecommendationId" ] == inference_recommendation_id
610
+ ),
611
+ None ,
612
+ )
0 commit comments