Skip to content

Commit 332b2ec

Browse files
author
Jonathan Makunga
committed
Pretty Print JS TGI Benchmark results
1 parent a2be460 commit 332b2ec

File tree

2 files changed

+43
-5
lines changed


src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
3535
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
3636
from sagemaker.serve.utils.tuning import _serial_benchmark, _concurrent_benchmark, _more_performant, \
37-
_pretty_print_results_tgi
37+
_pretty_print_results_tgi, _pretty_print_results_tgi_js
3838
from sagemaker.serve.utils.types import ModelServer
3939
from sagemaker.base_predictor import PredictorBase
4040
from sagemaker.jumpstart.model import JumpStartModel
@@ -392,7 +392,7 @@ def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
392392
logger.debug(
393393
"Failed to gather any tuning results. "
394394
"Please inspect the stack trace emitted from live logging for more details. "
395-
"Falling back to default serving.properties: %s",
395+
"Falling back to default model environment variable configurations: %s",
396396
self.pysdk_model.env,
397397
)
398398

@@ -517,7 +517,8 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
517517
str(e),
518518
)
519519
break
520-
except Exception:
520+
except Exception as e:
521+
logger.exception(e)
521522
logger.exception(
522523
"Deployment unsuccessful with SM_NUM_GPUS: %s. "
523524
"with uncovered exception",
@@ -530,7 +531,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
530531
"SM_NUM_GPUS": str(best_tuned_combination[1])
531532
})
532533

533-
_pretty_print_results_tgi(benchmark_results)
534+
_pretty_print_results_tgi_js(benchmark_results)
534535
logger.info(
535536
"Model Configuration: %s was most performant with avg latency: %s, "
536537
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
@@ -547,7 +548,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
547548
logger.debug(
548549
"Failed to gather any tuning results. "
549550
"Please inspect the stack trace emitted from live logging for more details. "
550-
"Falling back to default serving.properties: %s",
551+
"Falling back to default model environment variable configurations: %s",
551552
self.pysdk_model.env,
552553
)
553554

src/sagemaker/serve/utils/tuning.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,43 @@ def _pretty_print_results_tgi(results: dict):
9898
)
9999

100100

101+
def _pretty_print_results_tgi_js(results: dict):
    """Log a table of TGI JumpStart tuning benchmark results at INFO level.

    Args:
        results (dict): Maps average serial latency to a sequence of
            benchmark data for one tested configuration, laid out as:
            index 0 — env-var dict containing "SM_NUM_GPUS",
            index 1 — p90 latency (serial),
            index 2 — average tokens per second (serial),
            index 3 — throughput per second (concurrent),
            index 4 — standard deviation of response (concurrent).
            (Layout inferred from the indexing below — confirm against the
            caller that assembles ``benchmark_results``.)

    The rows are sorted ascending by the dict key (average latency), so the
    most performant configuration appears first, and the table is rendered
    with pandas between banner lines.
    """
    avg_latencies = []
    sm_num_gpus = []
    p90s = []
    avg_tokens_per_seconds = []
    throughput_per_seconds = []
    standard_deviations = []
    # Sort by key (average serial latency) so the fastest config is listed first.
    ordered = collections.OrderedDict(sorted(results.items()))

    for key, value in ordered.items():
        avg_latencies.append(key)
        sm_num_gpus.append(value[0]["SM_NUM_GPUS"])
        p90s.append(value[1])
        avg_tokens_per_seconds.append(value[2])
        throughput_per_seconds.append(value[3])
        standard_deviations.append(value[4])

    df = pd.DataFrame(
        {
            "AverageLatency (Serial)": avg_latencies,
            "P90_Latency (Serial)": p90s,
            "AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
            "ThroughputPerSecond (Concurrent)": throughput_per_seconds,
            "StandardDeviationResponse (Concurrent)": standard_deviations,
            "SM_NUM_GPUS": sm_num_gpus,
        }
    )
    # Lazy %-style args so the (potentially large) table is only rendered
    # when INFO logging is enabled.
    logger.info(
        "\n================================================================== Benchmark "
        "Results ==================================================================\n%s"
        "\n============================================================================"
        "===========================================================================\n",
        df.to_string(),
    )
136+
137+
101138
def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
102139
"""Placeholder docstring"""
103140
est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2

0 commit comments

Comments
 (0)