Commit 7993b77

Author: Jonathan Makunga (makungaj1)
Parent: 6687c56

Fix public optimize api signature (aws#1507)

* Fix public optimize api signature
* JS Compilation fix
* Refactoring

Co-authored-by: Jonathan Makunga <[email protected]>
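At a glance, the fix replaces the opaque `optimize(self, *args, **kwargs)` signature with explicit keyword parameters (see the model_builder.py diff below). A minimal sketch of a call against the new signature; the model ID, S3 path, and quantization environment keys are illustrative placeholders, not taken from this diff:

# Sketch only: exercises the explicit keyword signature introduced by this commit.
# Model ID, bucket, and OverrideEnvironment contents are hypothetical.
from sagemaker.serve.builder.model_builder import ModelBuilder

builder = ModelBuilder(model="meta-textgeneration-llama-3-8b")  # placeholder model ID

optimized_model = builder.optimize(
    instance_type="ml.g5.12xlarge",
    output_path="s3://my-bucket/optimized/",  # placeholder bucket
    quantization_config={
        "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},  # illustrative config
    },
    accept_eula=True,
)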

4 files changed, +103 -26 lines changed


src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 18 additions & 6 deletions

@@ -47,6 +47,7 @@
     _is_optimized,
     _custom_speculative_decoding,
     SPECULATIVE_DRAFT_MODEL,
+    _is_inferentia_or_trainium,
 )
 from sagemaker.serve.utils.predictors import (
     DjlLocalModePredictor,
@@ -714,10 +715,25 @@ def _optimize_for_jumpstart(
                 f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
             )
 
+        is_compilation = (quantization_config is None) and (
+            (compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
+        )
+
         pysdk_model_env_vars = dict()
-        if compilation_config:
+        if is_compilation:
             pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
 
+        optimization_config, override_env = _extract_optimization_config_and_env(
+            quantization_config, compilation_config
+        )
+        if not optimization_config and is_compilation:
+            override_env = override_env or pysdk_model_env_vars
+            optimization_config = {
+                "ModelCompilationConfig": {
+                    "OverrideEnvironment": override_env,
+                }
+            }
+
         if speculative_decoding_config:
             self._set_additional_model_source(speculative_decoding_config)
         else:
@@ -732,10 +748,6 @@ def _optimize_for_jumpstart(
         model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
         optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars)
 
-        optimization_config, override_env = _extract_optimization_config_and_env(
-            quantization_config, compilation_config
-        )
-
         output_config = {"S3OutputLocation": output_path}
         if kms_key:
             output_config["KmsKeyId"] = kms_key
@@ -775,7 +787,7 @@ def _optimize_for_jumpstart(
                 "AcceptEula": True
             }
 
-        if quantization_config or compilation_config:
+        if quantization_config or is_compilation:
             self.pysdk_model.env = _update_environment_variables(
                 optimization_env_vars, override_env
             )
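The behavioral core of this hunk is the new `is_compilation` flag: compilation is now triggered not only by an explicit `compilation_config` but also implicitly by an Inferentia/Trainium (Neuron) target instance, provided no quantization was requested. A standalone restatement of the rule, with hypothetical names standing in for the diff's local variables:

# Hypothetical restatement of the is_compilation rule added above.
def needs_compilation(quantization_config, compilation_config, on_neuron_instance):
    # No quantization requested, and compilation is either explicit or
    # implied by the Neuron target instance type.
    return quantization_config is None and (
        compilation_config is not None or on_neuron_instance
    )

assert needs_compilation(None, None, True)                 # implied by instance type
assert needs_compilation(None, {"some": "config"}, False)  # explicit config
assert not needs_compilation({"bits": 4}, None, True)      # quantization takes priority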

src/sagemaker/serve/builder/model_builder.py

Lines changed: 52 additions & 17 deletions

@@ -1083,25 +1083,47 @@ def _try_fetch_gpu_info(self):
                 f"Unable to determine single GPU size for instance: [{self.instance_type}]"
             )
 
-    def optimize(self, *args, **kwargs) -> Model:
-        """Runs a model optimization job.
+    def optimize(
+        self,
+        output_path: Optional[str] = None,
+        instance_type: Optional[str] = None,
+        role_arn: Optional[str] = None,
+        tags: Optional[Tags] = None,
+        job_name: Optional[str] = None,
+        accept_eula: Optional[bool] = None,
+        quantization_config: Optional[Dict] = None,
+        compilation_config: Optional[Dict] = None,
+        speculative_decoding_config: Optional[Dict] = None,
+        env_vars: Optional[Dict] = None,
+        vpc_config: Optional[Dict] = None,
+        kms_key: Optional[str] = None,
+        max_runtime_in_sec: Optional[int] = 36000,
+        sagemaker_session: Optional[Session] = None,
+    ) -> Model:
+        """Create an optimized deployable ``Model`` instance with ``ModelBuilder``.
 
         Args:
-            instance_type (Optional[str]): Target deployment instance type that the
-                model is optimized for.
-            output_path (Optional[str]): Specifies where to store the compiled/quantized model.
-            role_arn (Optional[str]): Execution role. Defaults to ``None``.
+            output_path (str): Specifies where to store the compiled/quantized model.
+            instance_type (str): Target deployment instance type that the model is optimized for.
+            role_arn (Optional[str]): Execution role ARN. Defaults to ``None``.
             tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
             job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
+            accept_eula (bool): For models that require a Model Access Config, specify ``True`` or
+                ``False`` to indicate whether model terms of use have been accepted.
+                The ``accept_eula`` value must be explicitly defined as ``True`` in order to
+                accept the end-user license agreement (EULA) that some
+                models require. Defaults to ``None``.
             quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
+            speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
+                Defaults to ``None``.
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
             kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
                 to S3. Defaults to ``None``.
             max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
-                ``None``.
+                36000 seconds.
             sagemaker_session (Optional[Session]): Session object which manages interactions
                 with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
                 function creates one using the default AWS configuration chain.
@@ -1113,7 +1135,22 @@ def optimize(self, *args, **kwargs) -> Model:
         # need to get telemetry_opt_out info before telemetry decorator is called
         self.serve_settings = self._get_serve_setting()
 
-        return self._model_builder_optimize_wrapper(*args, **kwargs)
+        return self._model_builder_optimize_wrapper(
+            output_path=output_path,
+            instance_type=instance_type,
+            role_arn=role_arn,
+            tags=tags,
+            job_name=job_name,
+            accept_eula=accept_eula,
+            quantization_config=quantization_config,
+            compilation_config=compilation_config,
+            speculative_decoding_config=speculative_decoding_config,
+            env_vars=env_vars,
+            vpc_config=vpc_config,
+            kms_key=kms_key,
+            max_runtime_in_sec=max_runtime_in_sec,
+            sagemaker_session=sagemaker_session,
+        )
 
     @_capture_telemetry("optimize")
     def _model_builder_optimize_wrapper(
@@ -1178,10 +1215,8 @@ def _model_builder_optimize_wrapper(
 
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
 
-        if instance_type:
-            self.instance_type = instance_type
-        if role_arn:
-            self.role_arn = role_arn
+        self.instance_type = instance_type or self.instance_type
+        self.role_arn = role_arn or self.role_arn
 
         self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
         job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
@@ -1266,7 +1301,7 @@ def _optimize_for_hf(
             ``None``.
 
         Returns:
-            Dict[str, Any]: Model optimization job input arguments.
+            Optional[Dict[str, Any]]: Model optimization job input arguments.
         """
         if self.model_server != ModelServer.DJL_SERVING:
             logger.info("Overwriting model server to DJL.")
@@ -1275,6 +1310,10 @@ def _optimize_for_hf(
         self.role_arn = role_arn or self.role_arn
         self.instance_type = instance_type or self.instance_type
 
+        self.pysdk_model = _custom_speculative_decoding(
+            self.pysdk_model, speculative_decoding_config, False
+        )
+
         if quantization_config or compilation_config:
             create_optimization_job_args = {
                 "OptimizationJobName": job_name,
@@ -1290,10 +1329,6 @@ def _optimize_for_hf(
             model_source = _generate_model_source(self.pysdk_model.model_data, False)
             create_optimization_job_args["ModelSource"] = model_source
 
-            self.pysdk_model = _custom_speculative_decoding(
-                self.pysdk_model, speculative_decoding_config, False
-            )
-
             optimization_config, override_env = _extract_optimization_config_and_env(
                 quantization_config, compilation_config
             )
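One subtlety the diff preserves: `optimize` reads `self.serve_settings` before delegating to the `@_capture_telemetry`-decorated wrapper, because the decorator needs the opt-out state at call time. A minimal sketch of that public-method-to-decorated-wrapper pattern; `_capture_telemetry` here is a simplified stand-in, not the SDK's implementation:

# Sketch of the pattern: public method prepares state, then calls a
# decorated private wrapper that reads it. Names are illustrative.
import functools

def _capture_telemetry(name):
    def deco(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Reads state that the public method set up *before* calling in.
            if not getattr(self, "telemetry_opt_out", True):
                print(f"recording telemetry event: {name}")
            return func(self, *args, **kwargs)
        return wrapper
    return deco

class Builder:
    def optimize(self, **kwargs):
        # Resolve settings first so the decorator below sees them.
        self.telemetry_opt_out = False  # stand-in for _get_serve_setting()
        return self._optimize_wrapper(**kwargs)

    @_capture_telemetry("optimize")
    def _optimize_wrapper(self, **kwargs):
        return kwargs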

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 20 additions & 3 deletions

@@ -26,6 +26,23 @@
 SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources"
 
 
+def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
+    """Checks whether an instance type is an Inferentia or Trainium instance.
+
+    Args:
+        instance_type (str): The instance type used for the compilation job.
+
+    Returns:
+        bool: Whether the given instance type is Inferentia or Trainium.
+    """
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match:
+            if match[1].startswith("inf") or match[1].startswith("trn"):
+                return True
+    return False
+
+
 def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool:
     """Checks whether an instance is compatible with an optimization job.
 
@@ -169,11 +186,11 @@ def _extracts_and_validates_speculative_model_source(
     Raises:
         ValueError: If model source is none.
     """
-    s3_uri: str = speculative_decoding_config.get("ModelSource")
+    model_source: str = speculative_decoding_config.get("ModelSource")
 
-    if not s3_uri:
+    if not model_source:
         raise ValueError("ModelSource must be provided in speculative decoding config.")
-    return s3_uri
+    return model_source
 
 
 def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str:
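For readers parsing the new helper, a quick stdlib-only check of how the regex extracts the instance family; the expected outcomes match this commit's unit tests, and the underscore form follows directly from the `[\._]` character class:

# Demo of the instance-type regex used by _is_inferentia_or_trainium above.
import re

PATTERN = re.compile(r"^ml[\._]([a-z\d]+)\.?\w*$")

for instance in ("ml.trn1.2xlarge", "ml.inf2.xlarge", "ml.c7gd.4xlarge", "ml_inf2.xlarge"):
    match = PATTERN.match(instance)
    family = match[1] if match else None
    is_neuron = family is not None and family.startswith(("inf", "trn"))
    print(f"{instance}: family={family}, neuron={is_neuron}")
# ml.trn1.2xlarge: family=trn1, neuron=True
# ml.inf2.xlarge: family=inf2, neuron=True
# ml.c7gd.4xlarge: family=c7gd, neuron=False
# ml_inf2.xlarge: family=inf2, neuron=True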

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 13 additions & 0 deletions

@@ -31,6 +31,7 @@
     _normalize_local_model_path,
     _is_optimized,
     _custom_speculative_decoding,
+    _is_inferentia_or_trainium,
 )
 
 mock_optimization_job_output = {
@@ -81,6 +82,18 @@
 }
 
 
+@pytest.mark.parametrize(
+    "instance, expected",
+    [
+        ("ml.trn1.2xlarge", True),
+        ("ml.inf2.xlarge", True),
+        ("ml.c7gd.4xlarge", False),
+    ],
+)
+def test_is_inferentia_or_trainium(instance, expected):
+    assert _is_inferentia_or_trainium(instance) == expected
+
+
 @pytest.mark.parametrize(
     "image_uri, expected",
    [
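Assuming a standard checkout of the SDK repository, the new parametrized cases can be run in isolation with pytest's keyword selector: `python -m pytest tests/unit/sagemaker/serve/utils/test_optimize_utils.py -k test_is_inferentia_or_trainium`.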
