13
13
"""Placeholder docstring"""
14
14
from __future__ import absolute_import
15
15
16
+ import copy
16
17
from abc import ABC , abstractmethod
18
+ from datetime import datetime , timedelta
17
19
from typing import Type
18
20
import logging
19
21
20
22
from sagemaker .model import Model
21
23
from sagemaker import model_uris
22
24
from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
25
+ from sagemaker .serve .model_server .djl_serving .utils import _get_admissible_tensor_parallel_degrees
23
26
from sagemaker .serve .model_server .tgi .prepare import prepare_tgi_js_resources , _create_dir_structure
24
27
from sagemaker .serve .mode .function_pointers import Mode
28
+ from sagemaker .serve .utils .exceptions import (
29
+ LocalDeepPingException ,
30
+ LocalModelOutOfMemoryException ,
31
+ LocalModelInvocationException ,
32
+ LocalModelLoadException ,
33
+ SkipTuningComboException ,
34
+ )
25
35
from sagemaker .serve .utils .predictors import (
26
36
DjlLocalModePredictor ,
27
37
TgiLocalModePredictor ,
28
38
)
29
- from sagemaker .serve .utils .local_hardware import _get_nb_instance , _get_ram_usage_mb
39
+ from sagemaker .serve .utils .local_hardware import (
40
+ _get_nb_instance ,
41
+ _get_ram_usage_mb ,
42
+ _get_available_gpus ,
43
+ )
30
44
from sagemaker .serve .utils .telemetry_logger import _capture_telemetry
45
+ from sagemaker .serve .utils .tuning import (
46
+ _serial_benchmark ,
47
+ _concurrent_benchmark ,
48
+ _more_performant ,
49
+ _pretty_print_benchmark_results ,
50
+ )
31
51
from sagemaker .serve .utils .types import ModelServer
32
52
from sagemaker .base_predictor import PredictorBase
33
53
from sagemaker .jumpstart .model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
134
154
model_data = self .pysdk_model .model_data ,
135
155
)
136
156
elif not hasattr (self , "prepared_for_tgi" ):
137
- self .prepared_for_tgi = prepare_tgi_js_resources (
157
+ ( self .js_model_config , self . prepared_for_tgi ) = prepare_tgi_js_resources (
138
158
model_path = self .model_path ,
139
159
js_id = self .model ,
140
160
dependencies = self .dependencies ,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
222
242
env = {}
223
243
if self .mode == Mode .LOCAL_CONTAINER :
224
244
if not hasattr (self , "prepared_for_tgi" ):
225
- self .prepared_for_tgi = prepare_tgi_js_resources (
245
+ ( self .js_model_config , self . prepared_for_tgi ) = prepare_tgi_js_resources (
226
246
model_path = self .model_path ,
227
247
js_id = self .model ,
228
248
dependencies = self .dependencies ,
@@ -234,6 +254,213 @@ def _build_for_tgi_jumpstart(self):
234
254
235
255
self .pysdk_model .env .update (env )
236
256
257
def _tune_for_js(
    self, multiple_model_copies_enabled: bool = False, max_tuning_duration: int = 1800
):
    """Tune for Jumpstart Models.

    Deploys the model once per admissible tensor parallel degree, benchmarks each
    deployment serially and concurrently, and writes the most performant
    combination (as judged by ``_more_performant``) into ``self.pysdk_model.env``.
    If no combination could be benchmarked, the environment captured before
    tuning is restored.

    Args:
        multiple_model_copies_enabled (bool): Whether multiple model copies serving is
            enabled by this ``DLC``. Defaults to ``False``.
        max_tuning_duration (int): Time budget in seconds. It is checked before each
            candidate is tried and is also passed to ``deploy`` as
            ``model_data_download_timeout``. Default: ``1800``.

    Returns:
        Tuned Model.
    """
    if self.mode != Mode.LOCAL_CONTAINER:
        logger.warning(
            "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
        )
        return self.pysdk_model

    # The env var carrying the shard count depends on what the model already declares.
    num_shard_env_var_name = "SM_NUM_GPUS"
    if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env:
        num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"

    # Snapshot taken before any candidate mutates the environment.
    initial_env_vars = copy.deepcopy(self.pysdk_model.env)
    admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
        self.js_model_config
    )
    available_gpus = _get_available_gpus() if multiple_model_copies_enabled else None

    benchmark_results = {}
    best_tuned_combination = None
    timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
    for tensor_parallel_degree in admissible_tensor_parallel_degrees:
        if datetime.now() > timeout:
            logger.info("Max tuning duration reached. Tuning stopped.")
            break

        sagemaker_model_server_workers = 1
        if multiple_model_copies_enabled:
            # NOTE(review): assumes _get_available_gpus() returns a number; if it
            # returns a collection of GPUs this needs len() - confirm with the helper.
            sagemaker_model_server_workers = int(available_gpus / tensor_parallel_degree)

        self.pysdk_model.env.update(
            {
                num_shard_env_var_name: str(tensor_parallel_degree),
                "SAGEMAKER_MODEL_SERVER_WORKERS": str(sagemaker_model_server_workers),
            }
        )

        try:
            predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)

            avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                predictor, self.schema_builder.sample_input
            )
            throughput_per_second, standard_deviation = _concurrent_benchmark(
                predictor, self.schema_builder.sample_input
            )

            tested_env = self.pysdk_model.env.copy()
            logger.info(
                "Average latency: %s, throughput/s: %s for configuration: %s",
                avg_latency,
                throughput_per_second,
                tested_env,
            )
            benchmark_results[avg_latency] = [
                tested_env,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]

            # Layout is positional: indices 1 and 2 are read back after the loop.
            tuned_configuration = [
                avg_latency,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]
            if not best_tuned_combination or _more_performant(
                best_tuned_combination, tuned_configuration
            ):
                best_tuned_combination = tuned_configuration
        except LocalDeepPingException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to invoke the model server: %s",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelOutOfMemoryException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Out of memory when loading the model: %s",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelInvocationException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to invoke the model server: %s. "
                "Please check that model server configurations are as expected "
                "(Ex. serialization, deserialization, content_type, accept).",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except LocalModelLoadException as e:
            logger.warning(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "Failed to load the model: %s.",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            break
        except SkipTuningComboException as e:
            logger.warning(
                "Deployment with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "was expected to be successful. However failed with: %s. "
                "Trying next combination.",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                str(e),
            )
            # The message promises another attempt, so move on to the next candidate
            # instead of abandoning the whole tuning run.
            continue
        except Exception:  # pylint: disable=W0703
            logger.exception(
                "Deployment unsuccessful with %s: %s. SAGEMAKER_MODEL_SERVER_WORKERS: %s. "
                "with uncovered exception",
                num_shard_env_var_name,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
            )
            break

    if best_tuned_combination:
        self.pysdk_model.env.update(
            {
                num_shard_env_var_name: str(best_tuned_combination[1]),
                "SAGEMAKER_MODEL_SERVER_WORKERS": str(best_tuned_combination[2]),
            }
        )

        _pretty_print_benchmark_results(
            benchmark_results, [num_shard_env_var_name, "SAGEMAKER_MODEL_SERVER_WORKERS"]
        )
        logger.info(
            "Model Configuration: %s was most performant with avg latency: %s, "
            "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
            "standard deviation of request %s",
            self.pysdk_model.env,
            best_tuned_combination[0],
            best_tuned_combination[3],
            best_tuned_combination[4],
            best_tuned_combination[5],
            best_tuned_combination[6],
        )
    else:
        # Restore in place: clearing first drops any keys that failed candidates
        # added, which a bare update() would leave behind.
        self.pysdk_model.env.clear()
        self.pysdk_model.env.update(initial_env_vars)
        logger.debug(
            "Failed to gather any tuning results. "
            "Please inspect the stack trace emitted from live logging for more details. "
            "Falling back to default model configurations: %s",
            self.pysdk_model.env,
        )

    return self.pysdk_model
447
+
448
@_capture_telemetry("djl_jumpstart.tune")
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
    """Tune for Jumpstart Models with DJL DLC.

    Delegates to ``_tune_for_js`` with multiple model copies serving enabled.

    Args:
        max_tuning_duration (int): Forwarded unchanged to ``_tune_for_js``.
            Default: ``1800``.

    Returns:
        Whatever ``_tune_for_js`` returns.
    """
    tuning_options = {
        "multiple_model_copies_enabled": True,
        "max_tuning_duration": max_tuning_duration,
    }
    return self._tune_for_js(**tuning_options)
454
+
455
@_capture_telemetry("tgi_jumpstart.tune")
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
    """Tune for Jumpstart Models with TGI DLC.

    Delegates to ``_tune_for_js`` with multiple model copies serving disabled.

    Args:
        max_tuning_duration (int): Forwarded unchanged to ``_tune_for_js``.
            Default: ``1800``.

    Returns:
        Whatever ``_tune_for_js`` returns.
    """
    tuning_options = {
        # Currently, TGI does not enable multiple model copies serving.
        "multiple_model_copies_enabled": False,
        "max_tuning_duration": max_tuning_duration,
    }
    return self._tune_for_js(**tuning_options)
463
+
237
464
def _build_for_jumpstart (self ):
238
465
"""Placeholder docstring"""
239
466
# we do not pickle for jumpstart. set to none
@@ -254,6 +481,8 @@ def _build_for_jumpstart(self):
254
481
self .image_uri = self .pysdk_model .image_uri
255
482
256
483
self ._build_for_djl_jumpstart ()
484
+
485
+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
257
486
elif "tgi-inference" in image_uri :
258
487
logger .info ("Building for TGI JumpStart Model ID..." )
259
488
self .model_server = ModelServer .TGI
@@ -262,6 +491,8 @@ def _build_for_jumpstart(self):
262
491
self .image_uri = self .pysdk_model .image_uri
263
492
264
493
self ._build_for_tgi_jumpstart ()
494
+
495
+ self .pysdk_model .tune = self .tune_for_tgi_jumpstart
265
496
else :
266
497
raise ValueError (
267
498
"JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
0 commit comments