@@ -246,7 +246,157 @@ def _build_for_tgi_jumpstart(self):
246
246
247
247
@_capture_telemetry("djl_jumpstart.tune")
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
    """Tune a DJL JumpStart model over admissible tensor-parallel degrees.

    Only supported in LOCAL_CONTAINER mode. For each admissible
    ``OPTION_TENSOR_PARALLEL_DEGREE`` the model is deployed locally,
    benchmarked serially and concurrently, and the most performant
    configuration is kept in ``self.pysdk_model.env``. If no combination
    can be benchmarked, the original environment is restored.

    Args:
        max_tuning_duration (int): Maximum seconds to spend tuning; also
            used as each deployment's model-data download timeout.
            Defaults to 1800.

    Returns:
        The (possibly re-configured) ``self.pysdk_model``.
    """
    if self.mode != Mode.LOCAL_CONTAINER:
        logger.warning(
            "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
        )
        return self.pysdk_model

    # Snapshot the env so we can roll back if every combination fails.
    initial_model_configuration = copy.deepcopy(self.pysdk_model.env)

    admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
        self.js_model_config
    )

    benchmark_results = {}
    best_tuned_combination = None
    timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
    for tensor_parallel_degree in admissible_tensor_parallel_degrees:
        if datetime.now() > timeout:
            logger.info("Max tuning duration reached. Tuning stopped.")
            break

        # DJL JumpStart tuning only varies tensor parallelism; the worker
        # count slot is kept as None so result tuples stay shape-compatible
        # with the TGI tuning path.
        sagemaker_model_server_workers = None
        self.pysdk_model.env.update(
            {"OPTION_TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree)}
        )

        try:
            predictor = self.pysdk_model.deploy(
                model_data_download_timeout=max_tuning_duration
            )

            avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
                predictor, self.schema_builder.sample_input
            )
            throughput_per_second, standard_deviation = _concurrent_benchmark(
                predictor, self.schema_builder.sample_input
            )

            tested_env = self.pysdk_model.env.copy()
            logger.info(
                "Average latency: %s, throughput/s: %s for configuration: %s",
                avg_latency,
                throughput_per_second,
                tested_env,
            )
            benchmark_results[avg_latency] = [
                tested_env,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]

            tuned_configuration = [
                avg_latency,
                tensor_parallel_degree,
                sagemaker_model_server_workers,
                p90,
                avg_tokens_per_second,
                throughput_per_second,
                standard_deviation,
            ]
            # First successful combination seeds the best; later ones must
            # beat it per _more_performant.
            if best_tuned_combination is None or _more_performant(
                best_tuned_combination, tuned_configuration
            ):
                best_tuned_combination = tuned_configuration
        except LocalDeepPingException as e:
            logger.warning(
                "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "Failed to invoke the model server: %s",
                tensor_parallel_degree,
                str(e),
            )
            break
        except LocalModelOutOfMemoryException as e:
            logger.warning(
                "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "Out of memory when loading the model: %s",
                tensor_parallel_degree,
                str(e),
            )
            break
        except LocalModelInvocationException as e:
            logger.warning(
                "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "Failed to invoke the model server: %s"
                "Please check that model server configurations are as expected "
                "(Ex. serialization, deserialization, content_type, accept).",
                tensor_parallel_degree,
                str(e),
            )
            break
        except LocalModelLoadException as e:
            logger.warning(
                "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "Failed to load the model: %s.",
                tensor_parallel_degree,
                str(e),
            )
            break
        except SkipTuningComboException as e:
            logger.warning(
                "Deployment with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "was expected to be successful. However failed with: %s. "
                "Trying next combination.",
                tensor_parallel_degree,
                str(e),
            )
            # BUGFIX: was `break`, which contradicted the "Trying next
            # combination." message — a skippable combo should not end the
            # whole sweep.
            continue
        except Exception:  # pylint: disable=broad-except
            logger.exception(
                "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
                "with uncovered exception",
                tensor_parallel_degree,
            )
            break

    if best_tuned_combination:
        self.pysdk_model.env.update(
            {"OPTION_TENSOR_PARALLEL_DEGREE": str(best_tuned_combination[1])}
        )

        # NOTE(review): reuses the TGI results printer; confirm there is no
        # DJL-specific variant intended here.
        _pretty_print_results_tgi(benchmark_results)
        logger.info(
            "Model Configuration: %s was most performant with avg latency: %s, "
            "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
            "standard deviation of request %s",
            self.pysdk_model.env,
            best_tuned_combination[0],
            best_tuned_combination[3],
            best_tuned_combination[4],
            best_tuned_combination[5],
            best_tuned_combination[6],
        )
    else:
        # Roll back to the configuration captured before tuning began.
        self.pysdk_model.env.update(initial_model_configuration)
        logger.debug(
            "Failed to gather any tuning results. "
            "Please inspect the stack trace emitted from live logging for more details. "
            "Falling back to default serving.properties: %s",
            self.pysdk_model.env,
        )

    return self.pysdk_model
250
400
251
401
@_capture_telemetry ("tgi_jumpstart.tune" )
252
402
def tune_for_tgi_jumpstart (self , max_tuning_duration : int = 1800 ):
@@ -352,7 +502,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
352
502
break
353
503
except LocalModelLoadException as e :
354
504
logger .warning (
355
- "Deployment unsuccessful with zSM_NUM_GPUS : %s. "
505
+ "Deployment unsuccessful with SM_NUM_GPUS : %s. "
356
506
"Failed to load the model: %s." ,
357
507
sm_num_gpus ,
358
508
str (e ),
@@ -438,4 +588,6 @@ def _build_for_jumpstart(self):
438
588
439
589
if self .model_server == ModelServer .TGI :
440
590
self .pysdk_model .tune = self .tune_for_tgi_jumpstart
591
+ elif self .model_server == ModelServer .DJL_SERVING :
592
+ self .pysdk_model .tune = self .tune_for_djl_jumpstart
441
593
return self .pysdk_model
0 commit comments