"""Placeholder docstring"""
from __future__ import absolute_import

+import copy
from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
from typing import Type
import logging

from sagemaker.model import Model
from sagemaker import model_uris
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
+from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.utils.exceptions import (
+    LocalDeepPingException,
+    LocalModelOutOfMemoryException,
+    LocalModelInvocationException,
+    LocalModelLoadException,
+    SkipTuningComboException,
+)
from sagemaker.serve.utils.predictors import (
    DjlLocalModePredictor,
    TgiLocalModePredictor,
)
-from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+)
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
+from sagemaker.serve.utils.tuning import (
+    _pretty_print_results_jumpstart,
+    _serial_benchmark,
+    _concurrent_benchmark,
+    _more_performant,
+    _sharded_supported,
+)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.base_predictor import PredictorBase
from sagemaker.jumpstart.model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
                    model_data=self.pysdk_model.model_data,
                )
            elif not hasattr(self, "prepared_for_tgi"):
-                self.prepared_for_tgi = prepare_tgi_js_resources(
+                self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                    model_path=self.model_path,
                    js_id=self.model,
                    dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
        env = {}
        if self.mode == Mode.LOCAL_CONTAINER:
            if not hasattr(self, "prepared_for_tgi"):
-                self.prepared_for_tgi = prepare_tgi_js_resources(
+                self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                    model_path=self.model_path,
                    js_id=self.model,
                    dependencies=self.dependencies,
@@ -234,6 +254,183 @@ def _build_for_tgi_jumpstart(self):

        self.pysdk_model.env.update(env)

+    def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
+        """Tune for JumpStart Models in Local Mode.
+
+        Args:
+            sharded_supported (bool): Indicates whether sharding is supported by this ``Model``.
+            max_tuning_duration (int): The maximum time, in seconds, to spend tuning this
+                ``Model`` locally. Default: ``1800``.
+
+        Returns:
+            The tuned ``Model``.
+        """
+        if self.mode != Mode.LOCAL_CONTAINER:
+            logger.warning(
+                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
+            )
+            return self.pysdk_model
+
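+        # DJL Serving and TGI name the shard-count setting differently: DJL reads
+        # OPTION_TENSOR_PARALLEL_DEGREE while TGI reads SM_NUM_GPUS, so pick
+        # whichever key is already present in the model's env.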
+        num_shard_env_var_name = "SM_NUM_GPUS"
+        if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys():
+            num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"
+
+        initial_env_vars = copy.deepcopy(self.pysdk_model.env)
+        admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
+            self.js_model_config
+        )
+
+        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported:
+            admissible_tensor_parallel_degrees = [1]
+            logger.warning(
+                "Sharding across multiple GPUs is not supported for this model. "
+                "The model can only run on a single GPU."
+            )
+
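+        # Try each admissible tensor parallel degree in turn: deploy the model
+        # locally, benchmark it, and keep the most performant configuration until
+        # the tuning time budget is exhausted.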
+        benchmark_results = {}
+        best_tuned_combination = None
+        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
+        for tensor_parallel_degree in admissible_tensor_parallel_degrees:
+            if datetime.now() > timeout:
+                logger.info("Max tuning duration reached. Tuning stopped.")
+                break
+
+            self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)})
+            try:
+                logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)
+
+                predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)
+
+                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+                throughput_per_second, standard_deviation = _concurrent_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+
+                tested_env = copy.deepcopy(self.pysdk_model.env)
+                logger.info(
+                    "Average latency: %s, throughput/s: %s for configuration: %s",
+                    avg_latency,
+                    throughput_per_second,
+                    tested_env,
+                )
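+                # Results are keyed by average latency; each entry records the env
+                # vars that were tested plus the p90 latency, avg tokens/s,
+                # throughput/s, and the standard deviation across concurrent requests.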
+                benchmark_results[avg_latency] = [
+                    tested_env,
+                    p90,
+                    avg_tokens_per_second,
+                    throughput_per_second,
+                    standard_deviation,
+                ]
+
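+                # A candidate combination is [avg_latency, tensor_parallel_degree,
+                # None, p90, avg_tokens_per_second, throughput_per_second,
+                # standard_deviation]; the None slot appears to be a placeholder that
+                # keeps the list shape consistent for _more_performant.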
+                if not best_tuned_combination:
+                    best_tuned_combination = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                else:
+                    tuned_configuration = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
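+                    # Keep whichever configuration benchmarks better on the
+                    # latency/throughput numbers gathered above.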
+                    if _more_performant(best_tuned_combination, tuned_configuration):
+                        best_tuned_combination = tuned_configuration
+            except LocalDeepPingException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to invoke the model server: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelOutOfMemoryException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Out of memory when loading the model: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelInvocationException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Failed to invoke the model server: %s. "
+                    "Please check that model server configurations are as expected "
+                    "(Ex. serialization, deserialization, content_type, accept).",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelLoadException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to load the model: %s.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except SkipTuningComboException as e:
+                logger.warning(
+                    "Deployment with %s: %s was expected to succeed, but failed with: %s. "
+                    "Trying the next combination.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except Exception:  # pylint: disable=W0703
+                logger.exception(
+                    "Deployment unsuccessful with %s: %s due to an unexpected exception.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                )
+
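+        # Pin the winning shard count on the model and report every benchmarked
+        # configuration; if nothing succeeded, restore the original env vars so a
+        # failed tuning run leaves the model unchanged.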
+        if best_tuned_combination:
+            self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})
+
+            _pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name])
+            logger.info(
+                "Model Configuration: %s was most performant with avg latency: %s, "
+                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
+                "standard deviation of requests: %s",
+                self.pysdk_model.env,
+                best_tuned_combination[0],
+                best_tuned_combination[3],
+                best_tuned_combination[4],
+                best_tuned_combination[5],
+                best_tuned_combination[6],
+            )
+        else:
+            self.pysdk_model.env.update(initial_env_vars)
+            logger.debug(
+                "Failed to gather any tuning results. "
+                "Please inspect the stack trace emitted from live logging for more details. "
+                "Falling back to default model configurations: %s",
+                self.pysdk_model.env,
+            )
+
+        return self.pysdk_model
+
+    @_capture_telemetry("djl_jumpstart.tune")
+    def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune JumpStart Models served with the DJL DLC."""
+        return self._tune_for_js(sharded_supported=True, max_tuning_duration=max_tuning_duration)
+
+    @_capture_telemetry("tgi_jumpstart.tune")
+    def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune JumpStart Models served with the TGI DLC."""
+        sharded_supported = _sharded_supported(self.model, self.js_model_config)
+        return self._tune_for_js(
+            sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
+        )
+
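+    # A minimal usage sketch (illustrative; assumes a text-generation JumpStart
+    # model ID and a SchemaBuilder with a sample prompt). After a LOCAL_CONTAINER
+    # build, tune() is attached to the built model by _build_for_jumpstart below:
+    #
+    #     model = ModelBuilder(
+    #         model="huggingface-llm-falcon-7b-bf16",
+    #         schema_builder=schema_builder,
+    #         mode=Mode.LOCAL_CONTAINER,
+    #     ).build()
+    #     tuned_model = model.tune(max_tuning_duration=1800)
+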
    def _build_for_jumpstart(self):
        """Placeholder docstring"""
        # we do not pickle for jumpstart. set to none
@@ -254,6 +451,8 @@ def _build_for_jumpstart(self):
            self.image_uri = self.pysdk_model.image_uri

            self._build_for_djl_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_djl_jumpstart
        elif "tgi-inference" in image_uri:
            logger.info("Building for TGI JumpStart Model ID...")
            self.model_server = ModelServer.TGI
@@ -262,6 +461,8 @@ def _build_for_jumpstart(self):
            self.image_uri = self.pysdk_model.image_uri

            self._build_for_tgi_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
        else:
            raise ValueError(
                "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."