Skip to content

Commit 6687c56

Browse files
makungaj1 and Jonathan Makunga authored
Fixing bugs (aws#1506)
* Fixing bugs * Refactoring * Increase coverage * Fix UT * Fix UT * Increase coverage * Fix UT * Refactoring * Fix UT --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 26c8696 commit 6687c56

File tree

8 files changed

+515
-56
lines changed

8 files changed

+515
-56
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -714,12 +714,9 @@ def _optimize_for_jumpstart(
714714
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
715715
)
716716

717-
optimization_env_vars = env_vars
718-
pysdk_model_env_vars = env_vars
719-
717+
pysdk_model_env_vars = dict()
720718
if compilation_config:
721-
neuron_env = self._get_neuron_model_env_vars(instance_type)
722-
optimization_env_vars = _update_environment_variables(neuron_env, optimization_env_vars)
719+
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
723720

724721
if speculative_decoding_config:
725722
self._set_additional_model_source(speculative_decoding_config)
@@ -730,28 +727,34 @@ def _optimize_for_jumpstart(
730727
config_name=deployment_config.get("DeploymentConfigName"),
731728
instance_type=deployment_config.get("InstanceType"),
732729
)
730+
pysdk_model_env_vars = self.pysdk_model.env
733731

734732
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
735-
optimization_config, env = _extract_optimization_config_and_env(
733+
optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars)
734+
735+
optimization_config, override_env = _extract_optimization_config_and_env(
736736
quantization_config, compilation_config
737737
)
738-
pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env)
739738

740739
output_config = {"S3OutputLocation": output_path}
741740
if kms_key:
742741
output_config["KmsKeyId"] = kms_key
743-
if not instance_type:
744-
instance_type = self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
745-
"InstanceType", _get_nb_instance()
746-
)
742+
743+
deployment_config_instance_type = (
744+
self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get("InstanceType")
745+
if self.pysdk_model.deployment_config
746+
else None
747+
)
748+
self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance()
749+
self.role_arn = role_arn or self.role_arn
747750

748751
create_optimization_job_args = {
749752
"OptimizationJobName": job_name,
750753
"ModelSource": model_source,
751-
"DeploymentInstanceType": instance_type,
754+
"DeploymentInstanceType": self.instance_type,
752755
"OptimizationConfigs": [optimization_config],
753756
"OutputConfig": output_config,
754-
"RoleArn": role_arn,
757+
"RoleArn": self.role_arn,
755758
}
756759

757760
if optimization_env_vars:
@@ -765,8 +768,6 @@ def _optimize_for_jumpstart(
765768
if vpc_config:
766769
create_optimization_job_args["VpcConfig"] = vpc_config
767770

768-
if pysdk_model_env_vars:
769-
self.pysdk_model.env.update(pysdk_model_env_vars)
770771
if accept_eula:
771772
self.pysdk_model.accept_eula = accept_eula
772773
if isinstance(self.pysdk_model.model_data, dict):
@@ -775,6 +776,9 @@ def _optimize_for_jumpstart(
775776
}
776777

777778
if quantization_config or compilation_config:
779+
self.pysdk_model.env = _update_environment_variables(
780+
optimization_env_vars, override_env
781+
)
778782
return create_optimization_job_args
779783
return None
780784

@@ -810,9 +814,13 @@ def _set_additional_model_source(
810814
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
811815

812816
if model_provider == "sagemaker":
813-
additional_model_data_sources = self.pysdk_model.deployment_config.get(
814-
"DeploymentArgs", {}
815-
).get("AdditionalDataSources")
817+
additional_model_data_sources = (
818+
self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
819+
"AdditionalDataSources"
820+
)
821+
if self.pysdk_model.deployment_config
822+
else None
823+
)
816824
if additional_model_data_sources is None:
817825
deployment_config = self._find_compatible_deployment_config(
818826
speculative_decoding_config

src/sagemaker/serve/builder/model_builder.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
from sagemaker.serve.utils.optimize_utils import (
7272
_generate_optimized_model,
7373
_generate_model_source,
74-
_update_environment_variables,
7574
_extract_optimization_config_and_env,
7675
_is_s3_uri,
7776
_normalize_local_model_path,
@@ -840,8 +839,7 @@ def build( # pylint: disable=R0911
840839
if role_arn:
841840
self.role_arn = role_arn
842841

843-
if not self.sagemaker_session:
844-
self.sagemaker_session = sagemaker_session or Session()
842+
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
845843

846844
self.sagemaker_session.settings._local_download_dir = self.model_path
847845

@@ -1111,8 +1109,6 @@ def optimize(self, *args, **kwargs) -> Model:
11111109
Returns:
11121110
Model: A deployable ``Model`` object.
11131111
"""
1114-
if self.mode != Mode.SAGEMAKER_ENDPOINT:
1115-
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
11161112

11171113
# need to get telemetry_opt_out info before telemetry decorator is called
11181114
self.serve_settings = self._get_serve_setting()
@@ -1174,6 +1170,9 @@ def _model_builder_optimize_wrapper(
11741170
speculative_decoding_config
11751171
)
11761172

1173+
if self.mode != Mode.SAGEMAKER_ENDPOINT:
1174+
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
1175+
11771176
if quantization_config and compilation_config:
11781177
raise ValueError("Quantization config and compilation config are mutually exclusive.")
11791178

@@ -1273,39 +1272,39 @@ def _optimize_for_hf(
12731272
logger.info("Overwriting model server to DJL.")
12741273
self.model_server = ModelServer.DJL_SERVING
12751274

1276-
optimization_env_vars = env_vars
1277-
pysdk_model_env_vars = env_vars
1275+
self.role_arn = role_arn or self.role_arn
1276+
self.instance_type = instance_type or self.instance_type
12781277

12791278
if quantization_config or compilation_config:
1280-
self.instance_type = instance_type or self.instance_type
1279+
create_optimization_job_args = {
1280+
"OptimizationJobName": job_name,
1281+
"DeploymentInstanceType": self.instance_type,
1282+
"RoleArn": self.role_arn,
1283+
}
1284+
1285+
if env_vars:
1286+
self.pysdk_model.env.update(env_vars)
1287+
create_optimization_job_args["OptimizationEnvironment"] = env_vars
12811288

12821289
self._optimize_prepare_for_hf()
12831290
model_source = _generate_model_source(self.pysdk_model.model_data, False)
1291+
create_optimization_job_args["ModelSource"] = model_source
12841292

12851293
self.pysdk_model = _custom_speculative_decoding(
12861294
self.pysdk_model, speculative_decoding_config, False
12871295
)
12881296

1289-
optimization_config, env = _extract_optimization_config_and_env(
1297+
optimization_config, override_env = _extract_optimization_config_and_env(
12901298
quantization_config, compilation_config
12911299
)
1292-
pysdk_model_env_vars = _update_environment_variables(pysdk_model_env_vars, env)
1300+
create_optimization_job_args["OptimizationConfigs"] = [optimization_config]
1301+
self.pysdk_model.env.update(override_env)
12931302

12941303
output_config = {"S3OutputLocation": output_path}
12951304
if kms_key:
12961305
output_config["KmsKeyId"] = kms_key
1306+
create_optimization_job_args["OutputConfig"] = output_config
12971307

1298-
create_optimization_job_args = {
1299-
"OptimizationJobName": job_name,
1300-
"ModelSource": model_source,
1301-
"DeploymentInstanceType": self.instance_type,
1302-
"OptimizationConfigs": [optimization_config],
1303-
"OutputConfig": output_config,
1304-
"RoleArn": role_arn,
1305-
}
1306-
1307-
if optimization_env_vars:
1308-
create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars
13091308
if max_runtime_in_sec:
13101309
create_optimization_job_args["StoppingCondition"] = {
13111310
"MaxRuntimeInSeconds": max_runtime_in_sec
@@ -1315,8 +1314,10 @@ def _optimize_for_hf(
13151314
if vpc_config:
13161315
create_optimization_job_args["VpcConfig"] = vpc_config
13171316

1318-
if pysdk_model_env_vars:
1319-
self.pysdk_model.env.update(pysdk_model_env_vars)
1317+
# HF_MODEL_ID must not be present; otherwise,
1318+
# HF model artifacts will be re-downloaded during deployment
1319+
if "HF_MODEL_ID" in self.pysdk_model.env:
1320+
del self.pysdk_model.env["HF_MODEL_ID"]
13201321

13211322
return create_optimization_job_args
13221323
return None

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,11 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -
5454
recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get(
5555
"RecommendedInferenceImage"
5656
)
57-
optimized_environment = optimization_response.get("OptimizationEnvironment")
5857
s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation")
5958
deployment_instance_type = optimization_response.get("DeploymentInstanceType")
6059

6160
if recommended_image_uri:
6261
pysdk_model.image_uri = recommended_image_uri
63-
if optimized_environment:
64-
if pysdk_model.env:
65-
pysdk_model.env.update(optimized_environment)
66-
else:
67-
pysdk_model.env = optimized_environment
6862
if s3_uri:
6963
pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri
7064
if deployment_instance_type:

tests/unit/sagemaker/serve/builder/test_djl_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def test_tune_for_djl_local_container_deep_ping_ex(
188188
tuned_model = model.tune()
189189
assert tuned_model.env == mock_default_configs
190190

191+
@patch("sagemaker.serve.builder.djl_builder._get_model_config_properties_from_hf")
191192
@patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None)
192193
@patch(
193194
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
@@ -211,7 +212,10 @@ def test_tune_for_djl_local_container_load_ex(
211212
mock_get_ram_usage_mb,
212213
mock_is_jumpstart_model,
213214
mock_telemetry,
215+
mock_get_model_config_properties_from_hf,
214216
):
217+
mock_get_model_config_properties_from_hf.return_value = {}
218+
215219
builder = ModelBuilder(
216220
model=mock_model_id,
217221
schema_builder=mock_schema_builder,

0 commit comments

Comments
 (0)