@@ -1083,25 +1083,47 @@ def _try_fetch_gpu_info(self):
1083
1083
f"Unable to determine single GPU size for instance: [{ self .instance_type } ]"
1084
1084
)
1085
1085
1086
- def optimize (self , * args , ** kwargs ) -> Model :
1087
- """Runs a model optimization job.
1086
+ def optimize (
1087
+ self ,
1088
+ output_path : Optional [str ] = None ,
1089
+ instance_type : Optional [str ] = None ,
1090
+ role_arn : Optional [str ] = None ,
1091
+ tags : Optional [Tags ] = None ,
1092
+ job_name : Optional [str ] = None ,
1093
+ accept_eula : Optional [bool ] = None ,
1094
+ quantization_config : Optional [Dict ] = None ,
1095
+ compilation_config : Optional [Dict ] = None ,
1096
+ speculative_decoding_config : Optional [Dict ] = None ,
1097
+ env_vars : Optional [Dict ] = None ,
1098
+ vpc_config : Optional [Dict ] = None ,
1099
+ kms_key : Optional [str ] = None ,
1100
+ max_runtime_in_sec : Optional [int ] = 36000 ,
1101
+ sagemaker_session : Optional [Session ] = None ,
1102
+ ) -> Model :
1103
+ """Create an optimized deployable ``Model`` instance with ``ModelBuilder``.
1088
1104
1089
1105
Args:
1090
- instance_type (Optional[str]): Target deployment instance type that the
1091
- model is optimized for.
1092
- output_path (Optional[str]): Specifies where to store the compiled/quantized model.
1093
- role_arn (Optional[str]): Execution role. Defaults to ``None``.
1106
+ output_path (str): Specifies where to store the compiled/quantized model.
1107
+ instance_type (str): Target deployment instance type that the model is optimized for.
1108
+ role_arn (Optional[str]): Execution role arn. Defaults to ``None``.
1094
1109
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
1095
1110
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
1111
+ accept_eula (bool): For models that require a Model Access Config, specify True or
1112
+ False to indicate whether model terms of use have been accepted.
1113
+ The `accept_eula` value must be explicitly defined as `True` in order to
1114
+ accept the end-user license agreement (EULA) that some
1115
+ models require. (Default: None).
1096
1116
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
1097
1117
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1118
+ speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1119
+ Defaults to ``None``
1098
1120
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1099
1121
container. Defaults to ``None``.
1100
1122
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
1101
1123
kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
1102
1124
to S3. Defaults to ``None``.
1103
1125
max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
1104
- ``None`` .
1126
+ 36000 seconds .
1105
1127
sagemaker_session (Optional[Session]): Session object which manages interactions
1106
1128
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
1107
1129
function creates one using the default AWS configuration chain.
@@ -1113,7 +1135,22 @@ def optimize(self, *args, **kwargs) -> Model:
1113
1135
# need to get telemetry_opt_out info before telemetry decorator is called
1114
1136
self .serve_settings = self ._get_serve_setting ()
1115
1137
1116
- return self ._model_builder_optimize_wrapper (* args , ** kwargs )
1138
+ return self ._model_builder_optimize_wrapper (
1139
+ output_path = output_path ,
1140
+ instance_type = instance_type ,
1141
+ role_arn = role_arn ,
1142
+ tags = tags ,
1143
+ job_name = job_name ,
1144
+ accept_eula = accept_eula ,
1145
+ quantization_config = quantization_config ,
1146
+ compilation_config = compilation_config ,
1147
+ speculative_decoding_config = speculative_decoding_config ,
1148
+ env_vars = env_vars ,
1149
+ vpc_config = vpc_config ,
1150
+ kms_key = kms_key ,
1151
+ max_runtime_in_sec = max_runtime_in_sec ,
1152
+ sagemaker_session = sagemaker_session ,
1153
+ )
1117
1154
1118
1155
@_capture_telemetry ("optimize" )
1119
1156
def _model_builder_optimize_wrapper (
@@ -1178,10 +1215,8 @@ def _model_builder_optimize_wrapper(
1178
1215
1179
1216
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1180
1217
1181
- if instance_type :
1182
- self .instance_type = instance_type
1183
- if role_arn :
1184
- self .role_arn = role_arn
1218
+ self .instance_type = instance_type or self .instance_type
1219
+ self .role_arn = role_arn or self .role_arn
1185
1220
1186
1221
self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1187
1222
job_name = job_name or f"modelbuilderjob-{ uuid .uuid4 ().hex } "
@@ -1266,7 +1301,7 @@ def _optimize_for_hf(
1266
1301
``None``.
1267
1302
1268
1303
Returns:
1269
- Dict[str, Any]: Model optimization job input arguments.
1304
+ Optional[ Dict[str, Any] ]: Model optimization job input arguments.
1270
1305
"""
1271
1306
if self .model_server != ModelServer .DJL_SERVING :
1272
1307
logger .info ("Overwriting model server to DJL." )
@@ -1275,6 +1310,10 @@ def _optimize_for_hf(
1275
1310
self .role_arn = role_arn or self .role_arn
1276
1311
self .instance_type = instance_type or self .instance_type
1277
1312
1313
+ self .pysdk_model = _custom_speculative_decoding (
1314
+ self .pysdk_model , speculative_decoding_config , False
1315
+ )
1316
+
1278
1317
if quantization_config or compilation_config :
1279
1318
create_optimization_job_args = {
1280
1319
"OptimizationJobName" : job_name ,
@@ -1290,10 +1329,6 @@ def _optimize_for_hf(
1290
1329
model_source = _generate_model_source (self .pysdk_model .model_data , False )
1291
1330
create_optimization_job_args ["ModelSource" ] = model_source
1292
1331
1293
- self .pysdk_model = _custom_speculative_decoding (
1294
- self .pysdk_model , speculative_decoding_config , False
1295
- )
1296
-
1297
1332
optimization_config , override_env = _extract_optimization_config_and_env (
1298
1333
quantization_config , compilation_config
1299
1334
)
0 commit comments