Skip to content

Commit b9c206f

Browse files
makungaj1 (Jonathan Makunga)
and others authored
fix: ModelBuilder not passing HF_TOKEN to model. (#4780)
* Follow-up fixes
* Refactoring
* Unit tests
* Refactoring

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 4c5dd1f commit b9c206f

File tree

5 files changed

+93
-108
lines changed

5 files changed

+93
-108
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,6 @@ def _optimize_for_jumpstart(
669669
self,
670670
output_path: Optional[str] = None,
671671
instance_type: Optional[str] = None,
672-
role_arn: Optional[str] = None,
673672
tags: Optional[Tags] = None,
674673
job_name: Optional[str] = None,
675674
accept_eula: Optional[bool] = None,
@@ -685,9 +684,7 @@ def _optimize_for_jumpstart(
685684
686685
Args:
687686
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
688-
instance_type (Optional[str]): Target deployment instance type that
689-
the model is optimized for.
690-
role_arn (Optional[str]): Execution role. Defaults to ``None``.
687+
instance_type (str): Target deployment instance type that the model is optimized for.
691688
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
692689
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
693690
accept_eula (bool): For models that require a Model Access Config, specify True or
@@ -715,7 +712,7 @@ def _optimize_for_jumpstart(
715712
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
716713
)
717714

718-
is_compilation = (quantization_config is None) and (
715+
is_compilation = (not quantization_config) and (
719716
(compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
720717
)
721718

@@ -758,7 +755,6 @@ def _optimize_for_jumpstart(
758755
else None
759756
)
760757
self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance()
761-
self.role_arn = role_arn or self.role_arn
762758

763759
create_optimization_job_args = {
764760
"OptimizationJobName": job_name,
@@ -787,10 +783,10 @@ def _optimize_for_jumpstart(
787783
"AcceptEula": True
788784
}
789785

786+
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
787+
if optimization_env_vars:
788+
self.pysdk_model.env.update(optimization_env_vars)
790789
if quantization_config or is_compilation:
791-
self.pysdk_model.env = _update_environment_variables(
792-
optimization_env_vars, override_env
793-
)
794790
return create_optimization_job_args
795791
return None
796792

src/sagemaker/serve/builder/model_builder.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
_generate_model_source,
7474
_extract_optimization_config_and_env,
7575
_is_s3_uri,
76-
_normalize_local_model_path,
7776
_custom_speculative_decoding,
7877
_extract_speculative_draft_model_provider,
7978
)
@@ -833,6 +832,8 @@ def build( # pylint: disable=R0911
833832
# until we deprecate HUGGING_FACE_HUB_TOKEN.
834833
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN") and not self.env_vars.get("HF_TOKEN"):
835834
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
835+
elif self.env_vars.get("HF_TOKEN") and not self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
836+
self.env_vars["HUGGING_FACE_HUB_TOKEN"] = self.env_vars.get("HF_TOKEN")
836837

837838
self.sagemaker_session.settings._local_download_dir = self.model_path
838839

@@ -851,7 +852,9 @@ def build( # pylint: disable=R0911
851852

852853
self._build_validations()
853854

854-
if not self._is_jumpstart_model_id() and self.model_server:
855+
if (
856+
not (isinstance(self.model, str) and self._is_jumpstart_model_id())
857+
) and self.model_server:
855858
return self._build_for_model_server()
856859

857860
if isinstance(self.model, str):
@@ -1216,18 +1219,15 @@ def _model_builder_optimize_wrapper(
12161219
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12171220

12181221
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
1219-
12201222
self.instance_type = instance_type or self.instance_type
12211223
self.role_arn = role_arn or self.role_arn
12221224

1223-
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12241225
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
1225-
12261226
if self._is_jumpstart_model_id():
1227+
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12271228
input_args = self._optimize_for_jumpstart(
12281229
output_path=output_path,
12291230
instance_type=instance_type,
1230-
role_arn=self.role_arn,
12311231
tags=tags,
12321232
job_name=job_name,
12331233
accept_eula=accept_eula,
@@ -1240,10 +1240,13 @@ def _model_builder_optimize_wrapper(
12401240
max_runtime_in_sec=max_runtime_in_sec,
12411241
)
12421242
else:
1243+
if self.model_server != ModelServer.DJL_SERVING:
1244+
logger.info("Overriding model server to DJL_SERVING.")
1245+
self.model_server = ModelServer.DJL_SERVING
1246+
1247+
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12431248
input_args = self._optimize_for_hf(
12441249
output_path=output_path,
1245-
instance_type=instance_type,
1246-
role_arn=self.role_arn,
12471250
tags=tags,
12481251
job_name=job_name,
12491252
quantization_config=quantization_config,
@@ -1269,8 +1272,6 @@ def _model_builder_optimize_wrapper(
12691272
def _optimize_for_hf(
12701273
self,
12711274
output_path: str,
1272-
instance_type: Optional[str] = None,
1273-
role_arn: Optional[str] = None,
12741275
tags: Optional[Tags] = None,
12751276
job_name: Optional[str] = None,
12761277
quantization_config: Optional[Dict] = None,
@@ -1285,9 +1286,6 @@ def _optimize_for_hf(
12851286
12861287
Args:
12871288
output_path (str): Specifies where to store the compiled/quantized model.
1288-
instance_type (Optional[str]): Target deployment instance type that
1289-
the model is optimized for.
1290-
role_arn (Optional[str]): Execution role. Defaults to ``None``.
12911289
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
12921290
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
12931291
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -1305,13 +1303,6 @@ def _optimize_for_hf(
13051303
Returns:
13061304
Optional[Dict[str, Any]]: Model optimization job input arguments.
13071305
"""
1308-
if self.model_server != ModelServer.DJL_SERVING:
1309-
logger.info("Overwriting model server to DJL.")
1310-
self.model_server = ModelServer.DJL_SERVING
1311-
1312-
self.role_arn = role_arn or self.role_arn
1313-
self.instance_type = instance_type or self.instance_type
1314-
13151306
self.pysdk_model = _custom_speculative_decoding(
13161307
self.pysdk_model, speculative_decoding_config, False
13171308
)
@@ -1371,13 +1362,12 @@ def _optimize_prepare_for_hf(self):
13711362
)
13721363
else:
13731364
if not custom_model_path:
1374-
custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}/code"
1365+
custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}"
13751366
download_huggingface_model_metadata(
13761367
self.model,
1377-
custom_model_path,
1368+
os.path.join(custom_model_path, "code"),
13781369
self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
13791370
)
1380-
custom_model_path = _normalize_local_model_path(custom_model_path)
13811371

13821372
self.pysdk_model.model_data, env = self._prepare_for_mode(
13831373
model_path=custom_model_path,

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,6 @@ def _extract_optimization_config_and_env(
282282
return None, None
283283

284284

285-
def _normalize_local_model_path(local_model_path: Optional[str]) -> Optional[str]:
286-
"""Normalizes the local model path.
287-
288-
Args:
289-
local_model_path (Optional[str]): The local model path.
290-
291-
Returns:
292-
Optional[str]: The normalized model path.
293-
"""
294-
if local_model_path is None:
295-
return local_model_path
296-
297-
# Removes /code or /code/ path at the end of local_model_path,
298-
# as it is appended during artifacts upload.
299-
pattern = r"/code/?$"
300-
if re.search(pattern, local_model_path):
301-
return re.sub(pattern, "", local_model_path)
302-
return local_model_path
303-
304-
305285
def _custom_speculative_decoding(
306286
model: Model,
307287
speculative_decoding_config: Optional[Dict],

0 commit comments

Comments (0)