Skip to content

Commit 9a3f6ca

Browse files
makungaj1 (Jonathan Makunga) and Jonathan Makunga authored
Bug bash fixes (aws#1492)
* HF Optimized
* Revert "HF Optimized"
* MB HF Optimize support
* Refactoring
* HF only s3 upload if optimize
* reuse role if provided in MB
* Refactoring
* New requirements
* Draft
* Refactoring
* Refactoring
* Bug Bash fixes
* UT
* UT
* Fix for parsing optimization output
* Tag fix
* UT
* UT

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 99345d8 commit 9a3f6ca

17 files changed: +337 additions, -101 deletions (lines changed)

src/sagemaker/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
format_tags,
6868
Tags,
6969
_resolve_routing_config,
70+
_validate_new_tags,
7071
)
7172
from sagemaker.async_inference import AsyncInferenceConfig
7273
from sagemaker.predictor_async import AsyncPredictor
@@ -412,10 +413,7 @@ def add_tags(self, tags: Tags) -> None:
412413
Args:
413414
tags (Tags): Tags to add.
414415
"""
415-
if self._tags and tags:
416-
self._tags.update(tags)
417-
else:
418-
self._tags = tags
416+
self._tags = _validate_new_tags(tags, self._tags)
419417

420418
@runnable_by_pipeline
421419
def register(

src/sagemaker/serve/builder/djl_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self):
100100
self.env_vars = None
101101
self.nb_instance_type = None
102102
self.ram_usage_model_load = None
103+
self.role_arn = None
103104

104105
@abstractmethod
105106
def _prepare_for_mode(self):
@@ -499,4 +500,8 @@ def _build_for_djl(self):
499500

500501
self.pysdk_model = self._build_for_hf_djl()
501502
self.pysdk_model.tune = self._tune_for_hf_djl
503+
if self.role_arn:
504+
self.pysdk_model.role = self.role_arn
505+
if self.sagemaker_session:
506+
self.pysdk_model.sagemaker_session = self.sagemaker_session
502507
return self.pysdk_model

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
SkipTuningComboException,
3939
)
4040
from sagemaker.serve.utils.optimize_utils import (
41-
_extract_model_source,
41+
_generate_model_source,
4242
_update_environment_variables,
4343
_extract_speculative_draft_model_provider,
4444
_is_image_compatible_with_optimization_job,
45-
_validate_optimization_inputs,
45+
_extracts_and_validates_speculative_model_source,
46+
_generate_channel_name,
47+
_generate_additional_model_data_sources,
4648
)
4749
from sagemaker.serve.utils.predictors import (
4850
DjlLocalModePredictor,
@@ -110,6 +112,7 @@ def __init__(self):
110112
self.ram_usage_model_load = None
111113
self.model_hub = None
112114
self.model_metadata = None
115+
self.role_arn = None
113116
self.is_fine_tuned = None
114117
self.is_gated = None
115118

@@ -544,7 +547,7 @@ def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> T
544547
)
545548
pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path
546549
pysdk_model.add_tags(
547-
{"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}
550+
{"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": fine_tuning_model_path}
548551
)
549552
logger.info(
550553
"FINE_TUNING_MODEL_PATH detected. Using fine-tuned model found in %s.",
@@ -633,6 +636,10 @@ def _build_for_jumpstart(self):
633636
"with djl-inference, tgi-inference, or mms-inference container."
634637
)
635638

639+
if self.role_arn:
640+
self.pysdk_model.role = self.role_arn
641+
if self.sagemaker_session:
642+
self.pysdk_model.sagemaker_session = self.sagemaker_session
636643
return self.pysdk_model
637644

638645
def _optimize_for_jumpstart(
@@ -650,7 +657,7 @@ def _optimize_for_jumpstart(
650657
vpc_config: Optional[Dict] = None,
651658
kms_key: Optional[str] = None,
652659
max_runtime_in_sec: Optional[int] = None,
653-
) -> Dict[str, Any]:
660+
) -> Optional[Dict[str, Any]]:
654661
"""Runs a model optimization job.
655662
656663
Args:
@@ -685,13 +692,9 @@ def _optimize_for_jumpstart(
685692
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
686693
)
687694

688-
_validate_optimization_inputs(
689-
output_path, instance_type, quantization_config, compilation_config
690-
)
691-
692695
optimization_env_vars = None
693696
pysdk_model_env_vars = None
694-
model_source = _extract_model_source(self.pysdk_model.model_data, accept_eula)
697+
model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula)
695698

696699
if speculative_decoding_config:
697700
self._set_additional_model_source(speculative_decoding_config)
@@ -745,8 +748,12 @@ def _optimize_for_jumpstart(
745748
if vpc_config:
746749
create_optimization_job_args["VpcConfig"] = vpc_config
747750

748-
self.pysdk_model.env.update(pysdk_model_env_vars)
749-
return create_optimization_job_args
751+
if pysdk_model_env_vars:
752+
self.pysdk_model.env.update(pysdk_model_env_vars)
753+
754+
if quantization_config or compilation_config:
755+
return create_optimization_job_args
756+
return None
750757

751758
def _is_gated_model(self, model=None) -> bool:
752759
"""Determine if ``this`` Model is Gated
@@ -779,14 +786,13 @@ def _set_additional_model_source(
779786
"""
780787
if speculative_decoding_config:
781788
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
789+
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
782790

783791
if model_provider.lower() == "sagemaker":
784-
if (
785-
self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get(
786-
"AdditionalDataSources"
787-
)
788-
is None
789-
):
792+
additional_model_data_sources = self.pysdk_model.deployment_config.get(
793+
"DeploymentArgs", {}
794+
).get("AdditionalDataSources")
795+
if additional_model_data_sources is None:
790796
deployment_config = self._find_compatible_deployment_config(
791797
speculative_decoding_config
792798
)
@@ -801,28 +807,26 @@ def _set_additional_model_source(
801807
)
802808

803809
self.pysdk_model.add_tags(
804-
{"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "sagemaker"},
810+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
805811
)
806812
else:
807-
s3_uri = speculative_decoding_config.get("ModelSource")
808-
if not s3_uri:
809-
raise ValueError("Custom S3 Uri cannot be none.")
810-
811-
# TODO: Set correct channel name.
812-
additional_model_data_source = {
813-
"ChannelName": "DraftModelName",
814-
"S3DataSource": {"S3Uri": s3_uri},
815-
}
816-
if accept_eula:
817-
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {
818-
"ACCEPT_EULA": True
819-
}
820-
821-
self.pysdk_model.additional_model_data_sources = [additional_model_data_source]
813+
s3_uri = _extracts_and_validates_speculative_model_source(
814+
speculative_decoding_config
815+
)
816+
817+
self.pysdk_model.additional_model_data_sources = (
818+
_generate_additional_model_data_sources(s3_uri, channel_name, accept_eula)
819+
)
822820
self.pysdk_model.add_tags(
823-
{"key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "value": "customer"},
821+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"},
824822
)
825823

824+
speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}"
825+
self.pysdk_model.env = _update_environment_variables(
826+
self.pysdk_model.env,
827+
{"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model},
828+
)
829+
826830
def _find_compatible_deployment_config(
827831
self, speculative_decoding_config: Optional[Dict] = None
828832
) -> Optional[Dict[str, Any]]:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@
6363
from sagemaker.serve.utils import task
6464
from sagemaker.serve.utils.exceptions import TaskNotFoundException
6565
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
66-
from sagemaker.serve.utils.optimize_utils import _generate_optimized_model
66+
from sagemaker.serve.utils.optimize_utils import (
67+
_generate_optimized_model,
68+
_validate_optimization_inputs,
69+
)
6770
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
6871
from sagemaker.serve.utils.hardware_detector import (
6972
_get_gpu_info,
@@ -87,7 +90,9 @@
8790
)
8891
from sagemaker.utils import Tags
8992
from sagemaker.workflow.entities import PipelineVariable
90-
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
93+
from sagemaker.huggingface.llm_utils import (
94+
get_huggingface_model_metadata,
95+
)
9196

9297
logger = logging.getLogger(__name__)
9398

@@ -383,7 +388,7 @@ def _get_serve_setting(self):
383388
sagemaker_session=self.sagemaker_session,
384389
)
385390

386-
def _prepare_for_mode(self):
391+
def _prepare_for_mode(self, should_upload_artifacts: bool = False):
387392
"""Placeholder docstring"""
388393
# TODO: move mode specific prepare steps under _model_builder_deploy_wrapper
389394
self.s3_upload_path = None
@@ -401,6 +406,7 @@ def _prepare_for_mode(self):
401406
self.sagemaker_session,
402407
self.image_uri,
403408
getattr(self, "model_hub", None) == ModelHub.JUMPSTART,
409+
should_upload=should_upload_artifacts,
404410
)
405411
self.env_vars.update(env_vars_sagemaker)
406412
return self.s3_upload_path, env_vars_sagemaker
@@ -479,6 +485,10 @@ def _create_model(self):
479485
self.pysdk_model.mode = self.mode
480486
self.pysdk_model.modes = self.modes
481487
self.pysdk_model.serve_settings = self.serve_settings
488+
if self.role_arn:
489+
self.pysdk_model.role = self.role_arn
490+
if self.sagemaker_session:
491+
self.pysdk_model.sagemaker_session = self.sagemaker_session
482492

483493
# dynamically generate a method to direct model.deploy() logic based on mode
484494
# unique method to models created via ModelBuilder()
@@ -935,8 +945,9 @@ def optimize(self, *args, **kwargs) -> Model:
935945
"""Runs a model optimization job.
936946
937947
Args:
938-
instance_type (str): Target deployment instance type that the model is optimized for.
939-
output_path (str): Specifies where to store the compiled/quantized model.
948+
instance_type (Optional[str]): Target deployment instance type that the
949+
model is optimized for.
950+
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
940951
role (Optional[str]): Execution role. Defaults to ``None``.
941952
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
942953
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
@@ -964,7 +975,7 @@ def optimize(self, *args, **kwargs) -> Model:
964975
@_capture_telemetry("optimize")
965976
def _model_builder_optimize_wrapper(
966977
self,
967-
output_path: str,
978+
output_path: Optional[str] = None,
968979
instance_type: Optional[str] = None,
969980
role: Optional[str] = None,
970981
tags: Optional[Tags] = None,
@@ -1010,11 +1021,15 @@ def _model_builder_optimize_wrapper(
10101021
Returns:
10111022
Model: A deployable ``Model`` object.
10121023
"""
1024+
_validate_optimization_inputs(
1025+
output_path, instance_type, quantization_config, compilation_config
1026+
)
1027+
10131028
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
10141029
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
10151030
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
10161031

1017-
input_args = {}
1032+
input_args = None
10181033
if self._is_jumpstart_model_id():
10191034
input_args = self._optimize_for_jumpstart(
10201035
output_path=output_path,
@@ -1032,8 +1047,9 @@ def _model_builder_optimize_wrapper(
10321047
max_runtime_in_sec=max_runtime_in_sec,
10331048
)
10341049

1035-
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
1036-
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1037-
self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status)
1050+
if input_args:
1051+
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
1052+
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1053+
self.pysdk_model = _generate_optimized_model(self.pysdk_model, job_status)
10381054

10391055
return self.pysdk_model

src/sagemaker/serve/builder/tei_builder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
162162
self.pysdk_model.role = kwargs.get("role")
163163
del kwargs["role"]
164164

165-
# set model_data to uncompressed s3 dict
166-
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
167-
self.env_vars.update(env_vars)
168-
self.pysdk_model.env.update(self.env_vars)
165+
self._prepare_for_mode()
169166

170167
# if the weights have been cached via local container mode -> set to offline
171168
if str(Mode.LOCAL_CONTAINER) in self.modes:
@@ -220,4 +217,8 @@ def _build_for_tei(self):
220217
self._set_to_tei()
221218

222219
self.pysdk_model = self._build_for_hf_tei()
220+
if self.role_arn:
221+
self.pysdk_model.role = self.role_arn
222+
if self.sagemaker_session:
223+
self.pysdk_model.sagemaker_session = self.sagemaker_session
223224
return self.pysdk_model

src/sagemaker/serve/builder/tf_serving_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def _create_tensorflow_model(self):
102102
self.pysdk_model.mode = self.mode
103103
self.pysdk_model.modes = self.modes
104104
self.pysdk_model.serve_settings = self.serve_settings
105+
if hasattr(self, "role_arn") and self.role_arn:
106+
self.pysdk_model.role = self.role_arn
107+
if hasattr(self, "sagemaker_session") and self.sagemaker_session:
108+
self.pysdk_model.sagemaker_session = self.sagemaker_session
105109

106110
self._original_deploy = self.pysdk_model.deploy
107111
self.pysdk_model.deploy = self._model_builder_deploy_wrapper

src/sagemaker/serve/builder/tgi_builder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,7 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
201201
self.pysdk_model.role = kwargs.get("role")
202202
del kwargs["role"]
203203

204-
# set model_data to uncompressed s3 dict
205-
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
206-
self.env_vars.update(env_vars)
207-
self.pysdk_model.env.update(self.env_vars)
204+
self._prepare_for_mode()
208205

209206
# if the weights have been cached via local container mode -> set to offline
210207
if str(Mode.LOCAL_CONTAINER) in self.modes:
@@ -472,4 +469,8 @@ def _build_for_tgi(self):
472469

473470
self.pysdk_model = self._build_for_hf_tgi()
474471
self.pysdk_model.tune = self._tune_for_hf_tgi
472+
if self.role_arn:
473+
self.pysdk_model.role = self.role_arn
474+
if self.sagemaker_session:
475+
self.pysdk_model.sagemaker_session = self.sagemaker_session
475476
return self.pysdk_model

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,7 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
223223
self.pysdk_model.role = kwargs.get("role")
224224
del kwargs["role"]
225225

226-
# set model_data to uncompressed s3 dict
227-
self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
228-
self.env_vars.update(env_vars)
229-
self.pysdk_model.env.update(self.env_vars)
226+
self._prepare_for_mode()
230227

231228
if "endpoint_logging" not in kwargs:
232229
kwargs["endpoint_logging"] = True
@@ -303,4 +300,8 @@ def _build_for_transformers(self):
303300

304301
self._build_transformers_env()
305302

303+
if self.role_arn:
304+
self.pysdk_model.role = self.role_arn
305+
if self.sagemaker_session:
306+
self.pysdk_model.sagemaker_session = self.sagemaker_session
306307
return self.pysdk_model

0 commit comments

Comments
 (0)