Skip to content

Commit 4d06e66

Browse files
author
Jonathan Makunga
committed
Pretty Print JS DJL Benchmark results
1 parent 332b2ec commit 4d06e66

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,26 @@
2525
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
2626
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
2727
from sagemaker.serve.mode.function_pointers import Mode
28-
from sagemaker.serve.utils.exceptions import LocalDeepPingException, LocalModelOutOfMemoryException, \
29-
LocalModelInvocationException, LocalModelLoadException, SkipTuningComboException
28+
from sagemaker.serve.utils.exceptions import (
29+
LocalDeepPingException,
30+
LocalModelOutOfMemoryException,
31+
LocalModelInvocationException,
32+
LocalModelLoadException,
33+
SkipTuningComboException
34+
)
3035
from sagemaker.serve.utils.predictors import (
3136
DjlLocalModePredictor,
3237
TgiLocalModePredictor,
3338
)
3439
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
3540
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
36-
from sagemaker.serve.utils.tuning import _serial_benchmark, _concurrent_benchmark, _more_performant, \
37-
_pretty_print_results_tgi, _pretty_print_results_tgi_js
41+
from sagemaker.serve.utils.tuning import (
42+
_serial_benchmark,
43+
_concurrent_benchmark,
44+
_more_performant,
45+
_pretty_print_results_djl_js,
46+
_pretty_print_results_tgi_js
47+
)
3848
from sagemaker.serve.utils.types import ModelServer
3949
from sagemaker.base_predictor import PredictorBase
4050
from sagemaker.jumpstart.model import JumpStartModel
@@ -375,7 +385,7 @@ def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
375385
"OPTION_TENSOR_PARALLEL_DEGREE": str(best_tuned_combination[1])
376386
})
377387

378-
_pretty_print_results_tgi(benchmark_results)
388+
_pretty_print_results_djl_js(benchmark_results)
379389
logger.info(
380390
"Model Configuration: %s was most performant with avg latency: %s, "
381391
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "

src/sagemaker/serve/utils/tuning.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,43 @@ def _pretty_print_results_tgi_js(results: dict):
135135
)
136136

137137

138+
def _pretty_print_results_djl_js(results: dict):
139+
"""Placeholder docstring"""
140+
avg_latencies = []
141+
option_tp_degree = []
142+
p90s = []
143+
avg_tokens_per_seconds = []
144+
throughput_per_seconds = []
145+
standard_deviations = []
146+
ordered = collections.OrderedDict(sorted(results.items()))
147+
148+
for key, value in ordered.items():
149+
avg_latencies.append(key)
150+
option_tp_degree.append(value[0]["OPTION_TENSOR_PARALLEL_DEGREE"])
151+
p90s.append(value[1])
152+
avg_tokens_per_seconds.append(value[2])
153+
throughput_per_seconds.append(value[3])
154+
standard_deviations.append(value[4])
155+
156+
df = pd.DataFrame(
157+
{
158+
"AverageLatency (Serial)": avg_latencies,
159+
"P90_Latency (Serial)": p90s,
160+
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
161+
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
162+
"StandardDeviationResponse (Concurrent)": standard_deviations,
163+
"OPTION_TENSOR_PARALLEL_DEGREE": option_tp_degree,
164+
}
165+
)
166+
logger.info(
167+
"\n================================================================== Benchmark "
168+
"Results ==================================================================\n%s"
169+
"\n============================================================================"
170+
"===========================================================================\n",
171+
df.to_string(),
172+
)
173+
174+
138175
def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
139176
"""Placeholder docstring"""
140177
est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2

0 commit comments

Comments
 (0)