38
38
SkipTuningComboException ,
39
39
)
40
40
from sagemaker .serve .utils .optimize_utils import (
41
- _is_compatible_with_optimization_job ,
42
41
_extract_model_source ,
43
42
_update_environment_variables ,
43
+ _extract_speculative_draft_model_provider ,
44
+ _is_image_compatible_with_optimization_job ,
45
+ _validate_optimization_inputs ,
44
46
)
45
47
from sagemaker .serve .utils .predictors import (
46
48
DjlLocalModePredictor ,
@@ -628,7 +630,7 @@ def _build_for_jumpstart(self):
628
630
629
631
def _optimize_for_jumpstart (
630
632
self ,
631
- output_path : str ,
633
+ output_path : Optional [ str ] = None ,
632
634
instance_type : Optional [str ] = None ,
633
635
role : Optional [str ] = None ,
634
636
tags : Optional [Tags ] = None ,
@@ -645,7 +647,7 @@ def _optimize_for_jumpstart(
645
647
"""Runs a model optimization job.
646
648
647
649
Args:
648
- output_path (str): Specifies where to store the compiled/quantized model.
650
+ output_path (Optional[ str] ): Specifies where to store the compiled/quantized model.
649
651
instance_type (Optional[str]): Target deployment instance type that
650
652
the model is optimized for.
651
653
role (Optional[str]): Execution role. Defaults to ``None``.
@@ -673,40 +675,30 @@ def _optimize_for_jumpstart(
673
675
"""
674
676
if self ._is_gated_model () and accept_eula is not True :
675
677
raise ValueError (
676
- f"ValueError: Model '{ self .model } ' "
677
- f"requires accepting end-user license agreement (EULA)."
678
+ f"Model '{ self .model } ' requires accepting end-user license agreement (EULA)."
678
679
)
679
680
681
+ _validate_optimization_inputs (
682
+ output_path , instance_type , quantization_config , compilation_config
683
+ )
684
+
680
685
optimization_env_vars = None
681
686
pysdk_model_env_vars = None
682
687
model_source = _extract_model_source (self .pysdk_model .model_data , accept_eula )
683
688
684
689
if speculative_decoding_config :
685
690
self ._set_additional_model_source (speculative_decoding_config )
686
- optimization_env_vars = self .pysdk_model .deployment_config .get ("DeploymentArgs" ). get (
687
- "Environment"
688
- )
691
+ optimization_env_vars = self .pysdk_model .deployment_config .get (
692
+ "DeploymentArgs" , {}
693
+ ). get ( "Environment" )
689
694
else :
690
- image_uri = None
691
- if quantization_config and quantization_config .get ("Image" ):
692
- image_uri = quantization_config .get ("Image" )
693
- elif compilation_config and compilation_config .get ("Image" ):
694
- image_uri = compilation_config .get ("Image" )
695
- instance_type = (
696
- instance_type
697
- or self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get ("InstanceType" )
698
- or _get_nb_instance ()
699
- )
700
- if not _is_compatible_with_optimization_job (instance_type , image_uri ):
701
- deployment_config = self ._find_compatible_deployment_config (None )
702
- if deployment_config :
703
- optimization_env_vars = deployment_config .get ("DeploymentArgs" ).get (
704
- "Environment"
705
- )
706
- self .pysdk_model .set_deployment_config (
707
- config_name = deployment_config .get ("DeploymentConfigName" ),
708
- instance_type = deployment_config .get ("InstanceType" ),
709
- )
695
+ deployment_config = self ._find_compatible_deployment_config (None )
696
+ if deployment_config :
697
+ optimization_env_vars = deployment_config .get ("DeploymentArgs" ).get ("Environment" )
698
+ self .pysdk_model .set_deployment_config (
699
+ config_name = deployment_config .get ("DeploymentConfigName" ),
700
+ instance_type = deployment_config .get ("InstanceType" ),
701
+ )
710
702
711
703
optimization_env_vars = _update_environment_variables (optimization_env_vars , env_vars )
712
704
@@ -736,7 +728,7 @@ def _optimize_for_jumpstart(
736
728
}
737
729
738
730
if optimization_env_vars :
739
- create_optimization_job_args ["Environment " ] = optimization_env_vars
731
+ create_optimization_job_args ["OptimizationEnvironment " ] = optimization_env_vars
740
732
if max_runtime_in_sec :
741
733
create_optimization_job_args ["StoppingCondition" ] = {
742
734
"MaxRuntimeInSeconds" : max_runtime_in_sec
@@ -766,18 +758,26 @@ def _is_gated_model(self, model=None) -> bool:
766
758
return "private" in s3_uri
767
759
768
760
def _set_additional_model_source (
769
- self , speculative_decoding_config : Optional [Dict [str , Any ]] = None
761
+ self ,
762
+ speculative_decoding_config : Optional [Dict [str , Any ]] = None ,
763
+ accept_eula : Optional [bool ] = None ,
770
764
) -> None :
771
765
"""Set Additional Model Source to ``this`` model.
772
766
773
767
Args:
774
768
speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config.
769
+ accept_eula (Optional[bool]): For models that require a Model Access Config.
775
770
"""
776
771
if speculative_decoding_config :
777
- model_provider : str = speculative_decoding_config [ "ModelProvider" ]
772
+ model_provider = _extract_speculative_draft_model_provider ( speculative_decoding_config )
778
773
779
774
if model_provider .lower () == "sagemaker" :
780
- if not self ._is_speculation_enabled (self .pysdk_model .deployment_config ):
775
+ if (
776
+ self .pysdk_model .deployment_config .get ("DeploymentArgs" , {}).get (
777
+ "AdditionalDataSources"
778
+ )
779
+ is None
780
+ ):
781
781
deployment_config = self ._find_compatible_deployment_config (
782
782
speculative_decoding_config
783
783
)
@@ -786,21 +786,30 @@ def _set_additional_model_source(
786
786
config_name = deployment_config .get ("DeploymentConfigName" ),
787
787
instance_type = deployment_config .get ("InstanceType" ),
788
788
)
789
- self .pysdk_model .add_tags (
790
- {"key" : Tag .SPECULATIVE_DRAFT_MODL_PROVIDER , "value" : "sagemaker" },
791
- )
792
789
else :
793
790
raise ValueError (
794
791
"Cannot find deployment config compatible for optimization job."
795
792
)
793
+
794
+ self .pysdk_model .add_tags (
795
+ {"key" : Tag .SPECULATIVE_DRAFT_MODL_PROVIDER , "value" : "sagemaker" },
796
+ )
796
797
else :
797
798
s3_uri = speculative_decoding_config .get ("ModelSource" )
798
799
if not s3_uri :
799
800
raise ValueError ("Custom S3 Uri cannot be none." )
800
801
801
- self .pysdk_model .additional_model_data_sources ["speculative_decoding" ][0 ][
802
- "s3_data_source"
803
- ]["s3_uri" ] = s3_uri
802
+ # TODO: Set correct channel name.
803
+ additional_model_data_source = {
804
+ "ChannelName" : "DraftModelName" ,
805
+ "S3DataSource" : {"S3Uri" : s3_uri },
806
+ }
807
+ if accept_eula :
808
+ additional_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = {
809
+ "ACCEPT_EULA" : True
810
+ }
811
+
812
+ self .pysdk_model .additional_model_data_sources = [additional_model_data_source ]
804
813
self .pysdk_model .add_tags (
805
814
{"key" : Tag .SPECULATIVE_DRAFT_MODL_PROVIDER , "value" : "customer" },
806
815
)
@@ -816,36 +825,20 @@ def _find_compatible_deployment_config(
816
825
Returns:
817
826
Optional[Dict[str, Any]]: A compatible model deployment config for optimization job.
818
827
"""
828
+ model_provider = _extract_speculative_draft_model_provider (speculative_decoding_config )
819
829
for deployment_config in self .pysdk_model .list_deployment_configs ():
820
- instance_type = deployment_config .get ("deployment_config" ).get ("InstanceType" )
821
- image_uri = deployment_config .get ("deployment_config" ).get ("ImageUri" )
822
-
823
- if _is_compatible_with_optimization_job (instance_type , image_uri ):
824
- if not speculative_decoding_config :
825
- return deployment_config
830
+ image_uri = deployment_config .get ("deployment_config" , {}).get ("ImageUri" )
826
831
827
- if self ._is_speculation_enabled (deployment_config ):
832
+ if _is_image_compatible_with_optimization_job (image_uri ):
833
+ if (
834
+ model_provider == "sagemaker"
835
+ and deployment_config .get ("DeploymentArgs" , {}).get ("AdditionalDataSources" )
836
+ ) or model_provider == "custom" :
828
837
return deployment_config
829
838
830
- return None
831
-
832
- def _is_speculation_enabled (self , deployment_config : Optional [Dict [str , Any ]]) -> bool :
833
- """Checks whether speculative is enabled for the given deployment config.
839
+ # There's no matching config from jumpstart to add sagemaker draft model location
840
+ if model_provider == "sagemaker" :
841
+ return None
834
842
835
- Args:
836
- deployment_config (Dict[str, Any]): A deployment config.
837
-
838
- Returns:
839
- bool: Whether speculative is enabled for this deployment config.
840
- """
841
- if deployment_config is None :
842
- return False
843
-
844
- acceleration_configs = deployment_config .get ("AccelerationConfigs" )
845
- if acceleration_configs :
846
- for acceleration_config in acceleration_configs :
847
- if acceleration_config .get (
848
- "type" , "default"
849
- ).lower () == "speculative" and acceleration_config .get ("enabled" ):
850
- return True
851
- return False
843
+ # fall back to the default jumpstart model deployment config for optimization job
844
+ return self .pysdk_model .deployment_config
0 commit comments