73
73
_generate_model_source ,
74
74
_extract_optimization_config_and_env ,
75
75
_is_s3_uri ,
76
- _normalize_local_model_path ,
77
76
_custom_speculative_decoding ,
78
77
_extract_speculative_draft_model_provider ,
79
78
)
@@ -833,6 +832,8 @@ def build( # pylint: disable=R0911
833
832
# until we deprecate HUGGING_FACE_HUB_TOKEN.
834
833
if self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ) and not self .env_vars .get ("HF_TOKEN" ):
835
834
self .env_vars ["HF_TOKEN" ] = self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
835
+ elif self .env_vars .get ("HF_TOKEN" ) and not self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ):
836
+ self .env_vars ["HUGGING_FACE_HUB_TOKEN" ] = self .env_vars .get ("HF_TOKEN" )
836
837
837
838
self .sagemaker_session .settings ._local_download_dir = self .model_path
838
839
@@ -851,7 +852,9 @@ def build( # pylint: disable=R0911
851
852
852
853
self ._build_validations ()
853
854
854
- if not self ._is_jumpstart_model_id () and self .model_server :
855
+ if (
856
+ not (isinstance (self .model , str ) and self ._is_jumpstart_model_id ())
857
+ ) and self .model_server :
855
858
return self ._build_for_model_server ()
856
859
857
860
if isinstance (self .model , str ):
@@ -1216,18 +1219,15 @@ def _model_builder_optimize_wrapper(
1216
1219
raise ValueError ("Quantization config and compilation config are mutually exclusive." )
1217
1220
1218
1221
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1219
-
1220
1222
self .instance_type = instance_type or self .instance_type
1221
1223
self .role_arn = role_arn or self .role_arn
1222
1224
1223
- self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1224
1225
job_name = job_name or f"modelbuilderjob-{ uuid .uuid4 ().hex } "
1225
-
1226
1226
if self ._is_jumpstart_model_id ():
1227
+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1227
1228
input_args = self ._optimize_for_jumpstart (
1228
1229
output_path = output_path ,
1229
1230
instance_type = instance_type ,
1230
- role_arn = self .role_arn ,
1231
1231
tags = tags ,
1232
1232
job_name = job_name ,
1233
1233
accept_eula = accept_eula ,
@@ -1240,10 +1240,13 @@ def _model_builder_optimize_wrapper(
1240
1240
max_runtime_in_sec = max_runtime_in_sec ,
1241
1241
)
1242
1242
else :
1243
+ if self .model_server != ModelServer .DJL_SERVING :
1244
+ logger .info ("Overriding model server to DJL_SERVING." )
1245
+ self .model_server = ModelServer .DJL_SERVING
1246
+
1247
+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1243
1248
input_args = self ._optimize_for_hf (
1244
1249
output_path = output_path ,
1245
- instance_type = instance_type ,
1246
- role_arn = self .role_arn ,
1247
1250
tags = tags ,
1248
1251
job_name = job_name ,
1249
1252
quantization_config = quantization_config ,
@@ -1269,8 +1272,6 @@ def _model_builder_optimize_wrapper(
1269
1272
def _optimize_for_hf (
1270
1273
self ,
1271
1274
output_path : str ,
1272
- instance_type : Optional [str ] = None ,
1273
- role_arn : Optional [str ] = None ,
1274
1275
tags : Optional [Tags ] = None ,
1275
1276
job_name : Optional [str ] = None ,
1276
1277
quantization_config : Optional [Dict ] = None ,
@@ -1285,9 +1286,6 @@ def _optimize_for_hf(
1285
1286
1286
1287
Args:
1287
1288
output_path (str): Specifies where to store the compiled/quantized model.
1288
- instance_type (Optional[str]): Target deployment instance type that
1289
- the model is optimized for.
1290
- role_arn (Optional[str]): Execution role. Defaults to ``None``.
1291
1289
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
1292
1290
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
1293
1291
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -1305,13 +1303,6 @@ def _optimize_for_hf(
1305
1303
Returns:
1306
1304
Optional[Dict[str, Any]]: Model optimization job input arguments.
1307
1305
"""
1308
- if self .model_server != ModelServer .DJL_SERVING :
1309
- logger .info ("Overwriting model server to DJL." )
1310
- self .model_server = ModelServer .DJL_SERVING
1311
-
1312
- self .role_arn = role_arn or self .role_arn
1313
- self .instance_type = instance_type or self .instance_type
1314
-
1315
1306
self .pysdk_model = _custom_speculative_decoding (
1316
1307
self .pysdk_model , speculative_decoding_config , False
1317
1308
)
@@ -1371,13 +1362,12 @@ def _optimize_prepare_for_hf(self):
1371
1362
)
1372
1363
else :
1373
1364
if not custom_model_path :
1374
- custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } /code "
1365
+ custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } "
1375
1366
download_huggingface_model_metadata (
1376
1367
self .model ,
1377
- custom_model_path ,
1368
+ os . path . join ( custom_model_path , "code" ) ,
1378
1369
self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ),
1379
1370
)
1380
- custom_model_path = _normalize_local_model_path (custom_model_path )
1381
1371
1382
1372
self .pysdk_model .model_data , env = self ._prepare_for_mode (
1383
1373
model_path = custom_model_path ,
0 commit comments