Skip to content

Commit f3b3504

Browse files
makungaj1 and Jonathan Makunga authored
update: Add optimize to ModelBuilder JS (aws#1485)
* MB JS Optimize * UT * Refactor * UT * UT * refactor * refactor --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 701b788 commit f3b3504

File tree

7 files changed

+179
-186
lines changed

7 files changed

+179
-186
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,6 +2568,7 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
25682568
"compute_resource_requirements",
25692569
"model_data_download_timeout",
25702570
"container_startup_health_check_timeout",
2571+
"additional_data_sources",
25712572
]
25722573

25732574
def __init__(
@@ -2597,6 +2598,7 @@ def __init__(
25972598
self.supported_instance_types = resolved_config.get(
25982599
"supported_inference_instance_types"
25992600
)
2601+
self.additional_data_sources = resolved_config.get("hosting_additional_data_sources")
26002602

26012603

26022604
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):

src/sagemaker/jumpstart/utils.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,20 +1364,3 @@ def wrapped_f(*args, **kwargs):
13641364
if _func is None:
13651365
return wrapper_cache
13661366
return wrapper_cache(_func)
1367-
1368-
1369-
def _extract_image_tag_and_version(image_uri: str) -> Tuple[Optional[str], Optional[str]]:
1370-
"""Extract Image tag and version from image URI.
1371-
1372-
Args:
1373-
image_uri (str): Image URI.
1374-
1375-
Returns:
1376-
Tuple[Optional[str], Optional[str]]: The tag and version of the image.
1377-
"""
1378-
if image_uri is None:
1379-
return None, None
1380-
1381-
tag = image_uri.split(":")[-1]
1382-
1383-
return tag, tag.split("-")[0]

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 59 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@
3838
SkipTuningComboException,
3939
)
4040
from sagemaker.serve.utils.optimize_utils import (
41-
_is_compatible_with_optimization_job,
4241
_extract_model_source,
4342
_update_environment_variables,
43+
_extract_speculative_draft_model_provider,
44+
_is_image_compatible_with_optimization_job,
45+
_validate_optimization_inputs,
4446
)
4547
from sagemaker.serve.utils.predictors import (
4648
DjlLocalModePredictor,
@@ -628,7 +630,7 @@ def _build_for_jumpstart(self):
628630

629631
def _optimize_for_jumpstart(
630632
self,
631-
output_path: str,
633+
output_path: Optional[str] = None,
632634
instance_type: Optional[str] = None,
633635
role: Optional[str] = None,
634636
tags: Optional[Tags] = None,
@@ -645,7 +647,7 @@ def _optimize_for_jumpstart(
645647
"""Runs a model optimization job.
646648
647649
Args:
648-
output_path (str): Specifies where to store the compiled/quantized model.
650+
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
649651
instance_type (Optional[str]): Target deployment instance type that
650652
the model is optimized for.
651653
role (Optional[str]): Execution role. Defaults to ``None``.
@@ -673,40 +675,30 @@ def _optimize_for_jumpstart(
673675
"""
674676
if self._is_gated_model() and accept_eula is not True:
675677
raise ValueError(
676-
f"ValueError: Model '{self.model}' "
677-
f"requires accepting end-user license agreement (EULA)."
678+
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
678679
)
679680

681+
_validate_optimization_inputs(
682+
output_path, instance_type, quantization_config, compilation_config
683+
)
684+
680685
optimization_env_vars = None
681686
pysdk_model_env_vars = None
682687
model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula)
683688

684689
if speculative_decoding_config:
685690
self._set_additional_model_source(speculative_decoding_config)
686-
optimization_env_vars = self.pysdk_model.deployment_config.get("DeploymentArgs").get(
687-
"Environment"
688-
)
691+
optimization_env_vars = self.pysdk_model.deployment_config.get(
692+
"DeploymentArgs", {}
693+
).get("Environment")
689694
else:
690-
image_uri = None
691-
if quantization_config and quantization_config.get("Image"):
692-
image_uri = quantization_config.get("Image")
693-
elif compilation_config and compilation_config.get("Image"):
694-
image_uri = compilation_config.get("Image")
695-
instance_type = (
696-
instance_type
697-
or self.pysdk_model.deployment_config.get("DeploymentArgs").get("InstanceType")
698-
or _get_nb_instance()
699-
)
700-
if not _is_compatible_with_optimization_job(instance_type, image_uri):
701-
deployment_config = self._find_compatible_deployment_config(None)
702-
if deployment_config:
703-
optimization_env_vars = deployment_config.get("DeploymentArgs").get(
704-
"Environment"
705-
)
706-
self.pysdk_model.set_deployment_config(
707-
config_name=deployment_config.get("DeploymentConfigName"),
708-
instance_type=deployment_config.get("InstanceType"),
709-
)
695+
deployment_config = self._find_compatible_deployment_config(None)
696+
if deployment_config:
697+
optimization_env_vars = deployment_config.get("DeploymentArgs").get("Environment")
698+
self.pysdk_model.set_deployment_config(
699+
config_name=deployment_config.get("DeploymentConfigName"),
700+
instance_type=deployment_config.get("InstanceType"),
701+
)
710702

711703
optimization_env_vars = _update_environment_variables(optimization_env_vars, env_vars)
712704

@@ -736,7 +728,7 @@ def _optimize_for_jumpstart(
736728
}
737729

738730
if optimization_env_vars:
739-
create_optimization_job_args["Environment"] = optimization_env_vars
731+
create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars
740732
if max_runtime_in_sec:
741733
create_optimization_job_args["StoppingCondition"] = {
742734
"MaxRuntimeInSeconds": max_runtime_in_sec
@@ -766,18 +758,26 @@ def _is_gated_model(self, model=None) -> bool:
766758
return "private" in s3_uri
767759

768760
def _set_additional_model_source(
769-
self, speculative_decoding_config: Optional[Dict[str, Any]] = None
761+
self,
762+
speculative_decoding_config: Optional[Dict[str, Any]] = None,
763+
accept_eula: Optional[bool] = None,
770764
) -> None:
771765
"""Set Additional Model Source to ``this`` model.
772766
773767
Args:
774768
speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config.
769+
accept_eula (Optional[bool]): For models that require a Model Access Config.
775770
"""
776771
if speculative_decoding_config:
777-
model_provider: str = speculative_decoding_config["ModelProvider"]
772+
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
778773

779774
if model_provider.lower() == "sagemaker":
780-
if not self._is_speculation_enabled(self.pysdk_model.deployment_config):
775+
if (
776+
self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
777+
"AdditionalDataSources"
778+
)
779+
is None
780+
):
781781
deployment_config = self._find_compatible_deployment_config(
782782
speculative_decoding_config
783783
)
@@ -786,21 +786,30 @@ def _set_additional_model_source(
786786
config_name=deployment_config.get("DeploymentConfigName"),
787787
instance_type=deployment_config.get("InstanceType"),
788788
)
789-
self.pysdk_model.add_tags(
790-
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"},
791-
)
792789
else:
793790
raise ValueError(
794791
"Cannot find deployment config compatible for optimization job."
795792
)
793+
794+
self.pysdk_model.add_tags(
795+
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "sagemaker"},
796+
)
796797
else:
797798
s3_uri = speculative_decoding_config.get("ModelSource")
798799
if not s3_uri:
799800
raise ValueError("Custom S3 Uri cannot be none.")
800801

801-
self.pysdk_model.additional_model_data_sources["speculative_decoding"][0][
802-
"s3_data_source"
803-
]["s3_uri"] = s3_uri
802+
# TODO: Set correct channel name.
803+
additional_model_data_source = {
804+
"ChannelName": "DraftModelName",
805+
"S3DataSource": {"S3Uri": s3_uri},
806+
}
807+
if accept_eula:
808+
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {
809+
"ACCEPT_EULA": True
810+
}
811+
812+
self.pysdk_model.additional_model_data_sources = [additional_model_data_source]
804813
self.pysdk_model.add_tags(
805814
{"key": Tag.SPECULATIVE_DRAFT_MODL_PROVIDER, "value": "customer"},
806815
)
@@ -816,36 +825,20 @@ def _find_compatible_deployment_config(
816825
Returns:
817826
Optional[Dict[str, Any]]: A compatible model deployment config for optimization job.
818827
"""
828+
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
819829
for deployment_config in self.pysdk_model.list_deployment_configs():
820-
instance_type = deployment_config.get("deployment_config").get("InstanceType")
821-
image_uri = deployment_config.get("deployment_config").get("ImageUri")
822-
823-
if _is_compatible_with_optimization_job(instance_type, image_uri):
824-
if not speculative_decoding_config:
825-
return deployment_config
830+
image_uri = deployment_config.get("deployment_config", {}).get("ImageUri")
826831

827-
if self._is_speculation_enabled(deployment_config):
832+
if _is_image_compatible_with_optimization_job(image_uri):
833+
if (
834+
model_provider == "sagemaker"
835+
and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources")
836+
) or model_provider == "custom":
828837
return deployment_config
829838

830-
return None
831-
832-
def _is_speculation_enabled(self, deployment_config: Optional[Dict[str, Any]]) -> bool:
833-
"""Checks whether speculative is enabled for the given deployment config.
839+
# There's no matching config from jumpstart to add sagemaker draft model location
840+
if model_provider == "sagemaker":
841+
return None
834842

835-
Args:
836-
deployment_config (Dict[str, Any]): A deployment config.
837-
838-
Returns:
839-
bool: Whether speculative is enabled for this deployment config.
840-
"""
841-
if deployment_config is None:
842-
return False
843-
844-
acceleration_configs = deployment_config.get("AccelerationConfigs")
845-
if acceleration_configs:
846-
for acceleration_config in acceleration_configs:
847-
if acceleration_config.get(
848-
"type", "default"
849-
).lower() == "speculative" and acceleration_config.get("enabled"):
850-
return True
851-
return False
843+
# fall back to the default jumpstart model deployment config for optimization job
844+
return self.pysdk_model.deployment_config

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from sagemaker import Model
2121
from sagemaker.enums import Tag
22-
from sagemaker.fw_utils import _is_gpu_instance
2322

2423

2524
logger = logging.getLogger(__name__)
@@ -42,30 +41,19 @@ def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
4241
return False
4342

4443

45-
def _is_compatible_with_optimization_job(
46-
instance_type: Optional[str], image_uri: Optional[str]
47-
) -> bool:
44+
def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool:
4845
"""Checks whether an instance is compatible with an optimization job.
4946
5047
Args:
51-
instance_type (str): The instance type used for the compilation job.
5248
image_uri (str): The image URI of the optimization job.
5349
5450
Returns:
5551
bool: Whether the given instance type is compatible with an optimization job.
5652
"""
57-
if not instance_type:
58-
return False
59-
60-
compatible_image = True
61-
if image_uri:
62-
compatible_image = "djl-inference:" in image_uri and (
63-
"-lmi" in image_uri or "-neuronx-" in image_uri
64-
)
65-
66-
return (
67-
_is_gpu_instance(instance_type) or _is_inferentia_or_trainium(instance_type)
68-
) and compatible_image
53+
# TODO: Use specific container type instead.
54+
if image_uri is None:
55+
return True
56+
return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri)
6957

7058

7159
def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model:
@@ -89,28 +77,6 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -
8977
return pysdk_model
9078

9179

92-
def _is_speculation_enabled(deployment_config: Optional[Dict[str, Any]]) -> bool:
93-
"""Checks whether speculation is enabled for this deployment config.
94-
95-
Args:
96-
deployment_config (Dict[str, Any]): A deployment config.
97-
98-
Returns:
99-
bool: Whether the speculation is enabled for this deployment config.
100-
"""
101-
if deployment_config is None:
102-
return False
103-
104-
acceleration_configs = deployment_config.get("AccelerationConfigs")
105-
if acceleration_configs:
106-
for acceleration_config in acceleration_configs:
107-
if acceleration_config.get("type").lower() == "speculation" and acceleration_config.get(
108-
"enabled"
109-
):
110-
return True
111-
return False
112-
113-
11480
def _extract_model_source(
11581
model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool]
11682
) -> Optional[Dict[str, Any]]:
@@ -129,7 +95,6 @@ def _extract_model_source(
12995
if isinstance(s3_uri, dict):
13096
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
13197

132-
# Todo: Inject fine-tune data source
13398
model_source = {"S3": {"S3Uri": s3_uri}}
13499
if accept_eula:
135100
model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True}
@@ -154,3 +119,61 @@ def _update_environment_variables(
154119
else:
155120
env = new_env
156121
return env
122+
123+
124+
def _extract_speculative_draft_model_provider(
125+
speculative_decoding_config: Optional[Dict] = None,
126+
) -> Optional[str]:
127+
"""Extracts speculative draft model provider from speculative decoding config.
128+
129+
Args:
130+
speculative_decoding_config (Optional[Dict]): A speculative decoding config.
131+
132+
Returns:
133+
Optional[str]: The speculative draft model provider.
134+
"""
135+
if speculative_decoding_config is None:
136+
return None
137+
138+
if speculative_decoding_config.get(
139+
"ModelProvider"
140+
) == "Custom" or speculative_decoding_config.get("ModelSource"):
141+
return "custom"
142+
143+
return "sagemaker"
144+
145+
146+
def _validate_optimization_inputs(
147+
output_path: Optional[str] = None,
148+
instance_type: Optional[str] = None,
149+
quantization_config: Optional[Dict] = None,
150+
compilation_config: Optional[Dict] = None,
151+
) -> None:
152+
"""Validates optimization inputs.
153+
154+
Args:
155+
output_path (Optional[str]): The output path.
156+
instance_type (Optional[str]): The instance type.
157+
quantization_config (Optional[Dict]): The quantization config.
158+
compilation_config (Optional[Dict]): The compilation config.
159+
160+
Raises:
161+
ValueError: If an optimization input is invalid.
162+
"""
163+
if quantization_config and compilation_config:
164+
raise ValueError("Quantization config and compilation config are mutually exclusive.")
165+
166+
instance_type_msg = "Please provide an instance type for %s optimization job."
167+
output_path_msg = "Please provide an output path for %s optimization job."
168+
169+
if quantization_config:
170+
if not instance_type:
171+
raise ValueError(instance_type_msg.format("quantization"))
172+
if not output_path:
173+
raise ValueError(output_path_msg.format("quantization"))
174+
175+
if compilation_config:
176+
if not instance_type:
177+
raise ValueError(instance_type_msg.format("compilation"))
178+
if not output_path:
179+
raise ValueError(output_path_msg.format("compilation"))

0 commit comments

Comments
 (0)