@@ -381,58 +381,34 @@ def _update_params_for_recommendation_id(
381
381
# Validate recommendation id
382
382
if not re .match (r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$" , inference_recommendation_id ):
383
383
raise ValueError ("Inference Recommendation id is not valid" )
384
- recommendation_job_name = inference_recommendation_id .split ("/" )[0 ]
384
+ job_or_model_name = inference_recommendation_id .split ("/" )[0 ]
385
385
386
386
sage_client = self .sagemaker_session .sagemaker_client
387
387
388
- # Retrieve model or inference recommendation job details
389
- recommendation_res , model_res = None , None
390
- try :
391
- recommendation_res = sage_client .describe_inference_recommendations_job (
392
- JobName = recommendation_job_name
393
- )
394
- except sage_client .exceptions .ResourceNotFound :
395
- pass
396
- try :
397
- model_res = sage_client .describe_model (ModelName = recommendation_job_name )
398
- except sage_client .exceptions .ResourceNotFound :
399
- pass
400
- if recommendation_res is None and model_res is None :
401
- raise ValueError ("Inference Recommendation id is not valid" )
388
+ # Desribe inference recommendation job and model details
389
+ recommendation_res , model_res = self ._describe_recommendation_job_and_model (
390
+ sage_client = sage_client ,
391
+ job_or_model_name = job_or_model_name ,
392
+ )
402
393
403
- # Search the recommendation from above describe result lists
404
- inference_recommendation , instant_recommendation = None , None
405
- if recommendation_res :
406
- inference_recommendation = next (
407
- (
408
- rec
409
- for rec in recommendation_res ["InferenceRecommendations" ]
410
- if rec ["RecommendationId" ] == inference_recommendation_id
411
- ),
412
- None ,
413
- )
414
- if model_res :
415
- instant_recommendation = next (
416
- (
417
- rec
418
- for rec in model_res ["DeploymentRecommendation" ][
419
- "RealTimeInferenceRecommendations"
420
- ]
421
- if rec ["RecommendationId" ] == inference_recommendation_id
422
- ),
423
- None ,
424
- )
425
- if inference_recommendation is None and instant_recommendation is None :
426
- raise ValueError ("Inference Recommendation id does not exist" )
394
+ # Search the recommendation from above describe results
395
+ (
396
+ right_size_recommendation ,
397
+ model_recommendation ,
398
+ ) = self ._get_right_size_and_model_recommendation (
399
+ recommendation_res = recommendation_res ,
400
+ model_res = model_res ,
401
+ inference_recommendation_id = inference_recommendation_id ,
402
+ )
427
403
428
- # Update params beased on instant recommendation
429
- if instant_recommendation :
404
+ # Update params beased on model recommendation
405
+ if model_recommendation :
430
406
if initial_instance_count is None :
431
407
raise ValueError (
432
- "Please specify initial_instance_count with instant recommendation id"
408
+ "Please specify initial_instance_count with model recommendation id"
433
409
)
434
- self .env .update (instant_recommendation ["Environment" ])
435
- instance_type = instant_recommendation ["InstanceType" ]
410
+ self .env .update (model_recommendation ["Environment" ])
411
+ instance_type = model_recommendation ["InstanceType" ]
436
412
return (instance_type , initial_instance_count )
437
413
438
414
# Update params based on default inference recommendation
@@ -443,7 +419,7 @@ def _update_params_for_recommendation_id(
443
419
"to override the recommendation."
444
420
)
445
421
input_config = recommendation_res ["InputConfig" ]
446
- model_config = inference_recommendation ["ModelConfiguration" ]
422
+ model_config = right_size_recommendation ["ModelConfiguration" ]
447
423
envs = (
448
424
model_config ["EnvironmentParameters" ]
449
425
if "EnvironmentParameters" in model_config
@@ -492,8 +468,8 @@ def _update_params_for_recommendation_id(
492
468
self .model_data = compilation_res ["ModelArtifacts" ]["S3ModelArtifacts" ]
493
469
self .image_uri = compilation_res ["InferenceImage" ]
494
470
495
- instance_type = inference_recommendation ["EndpointConfiguration" ]["InstanceType" ]
496
- initial_instance_count = inference_recommendation ["EndpointConfiguration" ][
471
+ instance_type = right_size_recommendation ["EndpointConfiguration" ]["InstanceType" ]
472
+ initial_instance_count = right_size_recommendation ["EndpointConfiguration" ][
497
473
"InitialInstanceCount"
498
474
]
499
475
@@ -563,3 +539,57 @@ def _convert_to_stopping_conditions_json(
563
539
threshold .to_json for threshold in model_latency_thresholds
564
540
]
565
541
return stopping_conditions
542
+
543
+ def _get_right_size_and_model_recommendation (
544
+ self ,
545
+ model_res = None ,
546
+ recommendation_res = None ,
547
+ inference_recommendation_id = None ,
548
+ ):
549
+ """Get recommendation from right size job or model"""
550
+ right_size_recommendation , model_recommendation = None , None
551
+ if recommendation_res :
552
+ right_size_recommendation = self ._get_recommendation (
553
+ recommendation_list = recommendation_res ["InferenceRecommendations" ],
554
+ inference_recommendation_id = inference_recommendation_id ,
555
+ )
556
+ if model_res :
557
+ model_recommendation = self ._get_recommendation (
558
+ recommendation_list = model_res ["DeploymentRecommendation" ][
559
+ "RealTimeInferenceRecommendations"
560
+ ],
561
+ inference_recommendation_id = inference_recommendation_id ,
562
+ )
563
+ if right_size_recommendation is None and model_recommendation is None :
564
+ raise ValueError ("Inference Recommendation id is not valid" )
565
+
566
+ return right_size_recommendation , model_recommendation
567
+
568
+ def _get_recommendation (self , recommendation_list , inference_recommendation_id ):
569
+ """Get recommendation based on recommendation id"""
570
+ return next (
571
+ (
572
+ rec
573
+ for rec in recommendation_list
574
+ if rec ["RecommendationId" ] == inference_recommendation_id
575
+ ),
576
+ None ,
577
+ )
578
+
579
+ def _describe_recommendation_job_and_model (self , sage_client , job_or_model_name ):
580
+ """Describe inference recommendation job and model results"""
581
+ recommendation_res , model_res = None , None
582
+ try :
583
+ recommendation_res = sage_client .describe_inference_recommendations_job (
584
+ JobName = job_or_model_name
585
+ )
586
+ except sage_client .exceptions .ResourceNotFound :
587
+ pass
588
+ try :
589
+ model_res = sage_client .describe_model (ModelName = job_or_model_name )
590
+ except sage_client .exceptions .ResourceNotFound :
591
+ pass
592
+ if recommendation_res is None and model_res is None :
593
+ raise ValueError ("Inference Recommendation id is not valid" )
594
+
595
+ return recommendation_res , model_res
0 commit comments