Commit a2be460

Jonathan Makunga committed
Tune support for JS model with DJL DLC
1 parent 01eb282 commit a2be460

File tree

1 file changed: +154 −2 lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 154 additions & 2 deletions
@@ -246,7 +246,157 @@ def _build_for_tgi_jumpstart(self):
 
     @_capture_telemetry("djl_jumpstart.tune")
     def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
-        pass
+        """Tune the JumpStart model locally for the DJL model server."""
+        if self.mode != Mode.LOCAL_CONTAINER:
+            logger.warning(
+                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
+            )
+            return self.pysdk_model
+
+        initial_model_configuration = copy.deepcopy(self.pysdk_model.env)
+
+        admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
+            self.js_model_config
+        )
+
+        benchmark_results = {}
+        best_tuned_combination = None
+        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
+        for tensor_parallel_degree in admissible_tensor_parallel_degrees:
+            if datetime.now() > timeout:
+                logger.info("Max tuning duration reached. Tuning stopped.")
+                break
+
+            sagemaker_model_server_workers = None
+            self.pysdk_model.env.update(
+                {"OPTION_TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree)}
+            )
+
+            try:
+                predictor = self.pysdk_model.deploy(
+                    model_data_download_timeout=max_tuning_duration
+                )
+
+                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+                throughput_per_second, standard_deviation = _concurrent_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+
+                tested_env = self.pysdk_model.env.copy()
+                logger.info(
+                    "Average latency: %s, throughput/s: %s for configuration: %s",
+                    avg_latency,
+                    throughput_per_second,
+                    tested_env,
+                )
+                benchmark_results[avg_latency] = [
+                    tested_env,
+                    p90,
+                    avg_tokens_per_second,
+                    throughput_per_second,
+                    standard_deviation,
+                ]
+
+                if not best_tuned_combination:
+                    best_tuned_combination = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        sagemaker_model_server_workers,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                else:
+                    tuned_configuration = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        sagemaker_model_server_workers,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                    if _more_performant(best_tuned_combination, tuned_configuration):
+                        best_tuned_combination = tuned_configuration
+            except LocalDeepPingException as e:
+                logger.warning(
+                    "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
+                    "Failed to invoke the model server: %s",
+                    tensor_parallel_degree,
+                    str(e),
+                )
+                break
+            except LocalModelOutOfMemoryException as e:
+                logger.warning(
+                    "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
+                    "Out of memory when loading the model: %s",
+                    tensor_parallel_degree,
+                    str(e),
+                )
+                break
+            except LocalModelInvocationException as e:
+                logger.warning(
+                    "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
+                    "Failed to invoke the model server: %s. "
+                    "Please check that model server configurations are as expected "
+                    "(Ex. serialization, deserialization, content_type, accept).",
+                    tensor_parallel_degree,
+                    str(e),
+                )
+                break
+            except LocalModelLoadException as e:
+                logger.warning(
+                    "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
+                    "Failed to load the model: %s.",
+                    tensor_parallel_degree,
+                    str(e),
+                )
+                break
+            except SkipTuningComboException as e:
+                logger.warning(
+                    "Deployment with OPTION_TENSOR_PARALLEL_DEGREE: %s "
+                    "was expected to be successful. However, it failed with: %s. "
+                    "Trying the next combination.",
+                    tensor_parallel_degree,
+                    str(e),
+                )
+                continue
+            except Exception:
+                logger.exception(
+                    "Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s "
+                    "with an uncovered exception.",
+                    tensor_parallel_degree,
+                )
+                break
+
+        if best_tuned_combination:
+            self.pysdk_model.env.update(
+                {"OPTION_TENSOR_PARALLEL_DEGREE": str(best_tuned_combination[1])}
+            )
+
+            _pretty_print_results_tgi(benchmark_results)
+            logger.info(
+                "Model Configuration: %s was most performant with avg latency: %s, "
+                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
+                "standard deviation of request latency: %s",
+                self.pysdk_model.env,
+                best_tuned_combination[0],
+                best_tuned_combination[3],
+                best_tuned_combination[4],
+                best_tuned_combination[5],
+                best_tuned_combination[6],
+            )
+        else:
+            self.pysdk_model.env.update(initial_model_configuration)
+            logger.debug(
+                "Failed to gather any tuning results. "
+                "Please inspect the stack trace emitted from live logging for more details. "
+                "Falling back to default serving.properties: %s",
+                self.pysdk_model.env,
+            )
+
+        return self.pysdk_model
 
     @_capture_telemetry("tgi_jumpstart.tune")
     def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
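
Note: `_serial_benchmark`, `_concurrent_benchmark`, and `_more_performant` are tuning helpers imported from the serve utilities and are not shown in this diff. As a rough illustration of the comparison step only, here is a minimal sketch of a `_more_performant`-style check, assuming lower average latency wins and near-ties fall back to concurrent throughput (the margin and tie-break rule are assumptions, not the helper's actual logic):

# Illustrative sketch only -- not the implementation in the serve tuning utilities.
# Index layout mirrors the lists built in the diff above:
# [avg_latency, tensor_parallel_degree, workers, p90, tokens/s, throughput/s, std_dev]
AVG_LATENCY_IDX = 0
THROUGHPUT_IDX = 5

def _more_performant_sketch(best: list, candidate: list) -> bool:
    """Return True when `candidate` should replace `best` (assumed rule)."""
    best_latency = best[AVG_LATENCY_IDX]
    cand_latency = candidate[AVG_LATENCY_IDX]
    # Assumption: latencies within 10% of each other count as a near-tie,
    # which is then decided by the higher concurrent throughput.
    if abs(cand_latency - best_latency) <= 0.1 * best_latency:
        return candidate[THROUGHPUT_IDX] > best[THROUGHPUT_IDX]
    return cand_latency < best_latency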
@@ -352,7 +502,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
                 break
             except LocalModelLoadException as e:
                 logger.warning(
-                    "Deployment unsuccessful with zSM_NUM_GPUS: %s. "
+                    "Deployment unsuccessful with SM_NUM_GPUS: %s. "
                     "Failed to load the model: %s.",
                     sm_num_gpus,
                     str(e),
@@ -438,4 +588,6 @@ def _build_for_jumpstart(self):
 
         if self.model_server == ModelServer.TGI:
            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
+        elif self.model_server == ModelServer.DJL_SERVING:
+            self.pysdk_model.tune = self.tune_for_djl_jumpstart
         return self.pysdk_model
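
With the last hunk in place, a DJL Serving JumpStart model built through ModelBuilder exposes `tune` the same way the TGI path already does. A hypothetical end-to-end sketch, assuming the `sagemaker.serve` ModelBuilder/SchemaBuilder flow (the model ID and sample payloads below are illustrative, not taken from this commit):

from sagemaker.serve import Mode, ModelBuilder, SchemaBuilder

# Illustrative sample request/response pair used to derive the schema.
sample_input = {"inputs": "What is DJL Serving?", "parameters": {"max_new_tokens": 64}}
sample_output = [{"generated_text": "DJL Serving is a high-performance model server."}]

model_builder = ModelBuilder(
    model="huggingface-llm-falcon-7b-bf16",  # illustrative JumpStart model ID
    schema_builder=SchemaBuilder(sample_input, sample_output),
    mode=Mode.LOCAL_CONTAINER,  # tuning is a LOCAL_CONTAINER-only capability
)

model = model_builder.build()

# Sweeps the admissible OPTION_TENSOR_PARALLEL_DEGREE values, benchmarks each
# local deployment, and returns the model with the best configuration applied.
tuned_model = model.tune(max_tuning_duration=1800)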
