
Commit 9b18684

Author: Jonathan Makunga (committed)
Commit message: Refactoring
1 parent 266fe2b, commit 9b18684

File tree: 3 files changed (+67 / -87 lines)


src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 54 additions & 22 deletions
@@ -42,10 +42,11 @@
 )
 from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
 from sagemaker.serve.utils.tuning import (
-    _more_performant_benchmark,
     _pretty_print_benchmark_results,
-    _run_serial_and_concurrent_benchmarks,
     sharded_supported,
+    _serial_benchmark,
+    _concurrent_benchmark,
+    _more_performant,
 )
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.base_predictor import PredictorBase
@@ -298,19 +299,52 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
             try:
                 logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)
 
-                result = _run_serial_and_concurrent_benchmarks(
-                    self.pysdk_model, self.schema_builder.sample_input, max_tuning_duration
+                predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)
+
+                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+                throughput_per_second, standard_deviation = _concurrent_benchmark(
+                    predictor, self.schema_builder.sample_input
                 )
-                benchmark_results[result["AVG_LATENCY"]] = {
-                    "TESTED_ENV": result["TESTED_ENV"],
-                    "P90": result["P90"],
-                    "AVG_TOKENS_PER_SECOND": result["AVG_TOKENS_PER_SECOND"],
-                    "THROUGHPUT_PER_SECOND": result["THROUGHPUT_PER_SECOND"],
-                    "STD_DEVIATION": result["STD_DEVIATION"],
-                }
-
-                result[num_shard_env_var_name] = tensor_parallel_degree
-                best_tuned_combination = _more_performant_benchmark(best_tuned_combination, result)
+
+                tested_env = copy.deepcopy(self.pysdk_model.env)
+                logger.info(
+                    "Average latency: %s, throughput/s: %s for configuration: %s",
+                    avg_latency,
+                    throughput_per_second,
+                    tested_env,
+                )
+                benchmark_results[avg_latency] = [
+                    tested_env,
+                    p90,
+                    avg_tokens_per_second,
+                    throughput_per_second,
+                    standard_deviation,
+                ]
+
+                if not best_tuned_combination:
+                    best_tuned_combination = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                else:
+                    tuned_configuration = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                    if _more_performant(best_tuned_combination, tuned_configuration):
+                        best_tuned_combination = tuned_configuration
             except LocalDeepPingException as e:
                 logger.warning(
                     "Deployment unsuccessful with %s: %s. " "Failed to invoke the model server: %s",
@@ -360,21 +394,19 @@ def _tune_for_js(self, max_tuning_duration: int = 1800):
                )
 
         if best_tuned_combination:
-            self.pysdk_model.env.update(
-                {num_shard_env_var_name: str(best_tuned_combination[num_shard_env_var_name])}
-            )
+            self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})
 
             _pretty_print_benchmark_results(benchmark_results, [num_shard_env_var_name])
             logger.info(
                 "Model Configuration: %s was most performant with avg latency: %s, "
                 "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
                 "standard deviation of request %s",
                 self.pysdk_model.env,
-                best_tuned_combination["AVG_LATENCY"],
-                best_tuned_combination["P90"],
-                best_tuned_combination["AVG_TOKENS_PER_SECOND"],
-                best_tuned_combination["THROUGHPUT_PER_SECOND"],
-                best_tuned_combination["STD_DEVIATION"],
+                best_tuned_combination[0],
+                best_tuned_combination[3],
+                best_tuned_combination[4],
+                best_tuned_combination[5],
+                best_tuned_combination[6],
             )
         else:
             self.pysdk_model.env.update(initial_env_vars)
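
Note on the new data layout: the refactor drops the dict previously returned by _run_serial_and_concurrent_benchmarks and stores plain lists, so the summary log above reads metrics back by position. Below is a minimal illustrative sketch (not part of the commit) of the positional layout implied by the diff, filled with dummy values:

    # Illustrative sketch with dummy values; index comments describe how the
    # refactored _tune_for_js reads the entries back, per the diff above.
    avg_latency, p90, avg_tokens_per_second = 0.42, 0.55, 31.0
    throughput_per_second, standard_deviation = 12.0, 0.03
    tensor_parallel_degree = 2
    tested_env = {"OPTION_TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree)}

    # benchmark_results: avg latency -> [tested_env, p90, tokens/s, throughput/s, std dev]
    benchmark_results = {
        avg_latency: [tested_env, p90, avg_tokens_per_second, throughput_per_second, standard_deviation]
    }

    # best_tuned_combination, read back positionally in the summary log:
    best_tuned_combination = [
        avg_latency,             # [0] average latency
        tensor_parallel_degree,  # [1] written into num_shard_env_var_name
        None,                    # [2] placeholder, unused on this code path
        p90,                     # [3] p90 latency
        avg_tokens_per_second,   # [4]
        throughput_per_second,   # [5]
        standard_deviation,      # [6]
    ]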

src/sagemaker/serve/utils/tuning.py

Lines changed: 11 additions & 59 deletions
@@ -1,7 +1,6 @@
 """Holds mixin logic to support deployment of Model ID"""
 from __future__ import absolute_import
 
-import copy
 import logging
 from time import perf_counter
 import collections
@@ -118,13 +117,13 @@ def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
 
     for key, value in ordered.items():
         avg_latencies.append(key)
-        p90s.append(value["P90"])
-        avg_tokens_per_seconds.append(value["AVG_TOKENS_PER_SECOND"])
-        throughput_per_seconds.append(value["THROUGHPUT_PER_SECOND"])
-        standard_deviations.append(value["STD_DEVIATION"])
+        p90s.append(value[1])
+        avg_tokens_per_seconds.append(value[2])
+        throughput_per_seconds.append(value[3])
+        standard_deviations.append(value[4])
 
         for model_env_var in __env_var_data:
-            __env_var_data[model_env_var].append(value["TESTED_ENV"][model_env_var])
+            __env_var_data[model_env_var].append(value[0][model_env_var])
 
     df = pd.DataFrame(
         {
@@ -137,13 +136,13 @@ def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
         }
     )
 
-    separator = "=" * 78
-    log_message = (
-        f"\n{separator} Benchmark Results {separator}\n"
-        f"{df.to_string()}\n"
-        f"{separator}{separator}\n"
+    logger.info(
+        "\n================================================================== Benchmark "
+        "Results ==================================================================\n%s"
+        "\n============================================================================"
+        "===========================================================================\n",
+        df.to_string(),
     )
-    logger.info(log_message)
 
 
 def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int:
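
For context, here is a standalone sketch of how the re-indexed results flow through the table-building loop above and into the logged DataFrame. It assumes pandas is installed; the sample values and column labels are made up for illustration and are not the ones defined in tuning.py:

    import collections
    import pandas as pd

    # Two fake benchmark runs keyed by average latency, shaped like the new
    # list values: [tested_env, p90, tokens/s, throughput/s, std dev].
    results = {
        0.42: [{"SM_NUM_GPUS": "2"}, 0.55, 31.0, 12.0, 0.03],
        0.51: [{"SM_NUM_GPUS": "1"}, 0.70, 24.0, 9.0, 0.05],
    }
    ordered = collections.OrderedDict(sorted(results.items()))

    avg_latencies, p90s, tokens, throughputs, std_devs, num_gpus = [], [], [], [], [], []
    for key, value in ordered.items():
        avg_latencies.append(key)
        p90s.append(value[1])
        tokens.append(value[2])
        throughputs.append(value[3])
        std_devs.append(value[4])
        num_gpus.append(value[0]["SM_NUM_GPUS"])

    df = pd.DataFrame(
        {
            "AvgLatency": avg_latencies,
            "P90": p90s,
            "TokensPerSec": tokens,
            "ThroughputPerSec": throughputs,
            "StdDev": std_devs,
            "SM_NUM_GPUS": num_gpus,
        }
    )
    print(df.to_string())  # the real function logs this table via logger.info
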
@@ -266,53 +265,6 @@ def _more_performant(best_tuned_configuration: list, tuned_configuration: list)
     return tuned_avg_latency <= best_avg_latency
 
 
-def _more_performant_benchmark(
-    best_tuned_configuration: dict, current_tuned_configuration: dict
-) -> dict:
-    """Returns the configuration with the lowest latency"""
-    if best_tuned_configuration is None:
-        return current_tuned_configuration
-
-    best_avg_latency = best_tuned_configuration["AGV_LATENCY"]
-    current_tuned_avg_latency = current_tuned_configuration["AGV_LATENCY"]
-    best_standard_deviation = best_tuned_configuration["STD_DEVIATION"]
-    current_tuned_standard_deviation = current_tuned_configuration["STD_DEVIATION"]
-
-    if _within_margins(MARGIN, 5, current_tuned_avg_latency, best_avg_latency):
-        if current_tuned_standard_deviation <= best_standard_deviation:
-            return current_tuned_configuration
-        return best_tuned_configuration
-
-    if current_tuned_avg_latency <= best_avg_latency:
-        return current_tuned_configuration
-    return best_tuned_configuration
-
-
-def _run_serial_and_concurrent_benchmarks(pysdk_model, sample_input, max_tuning_duration) -> dict:
-    """Run the benchmarks"""
-    predictor = pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)
-
-    avg_latency, p90, avg_tokens_per_second = _serial_benchmark(predictor, sample_input)
-    throughput_per_second, standard_deviation = _concurrent_benchmark(predictor, sample_input)
-
-    tested_env = copy.deepcopy(pysdk_model.env)
-    logger.info(
-        "Average latency: %s, throughput/s: %s for configuration: %s",
-        avg_latency,
-        throughput_per_second,
-        tested_env,
-    )
-
-    return {
-        "AVG_LATENCY": avg_latency,
-        "TESTED_ENV": tested_env,
-        "P90": p90,
-        "AVG_TOKENS_PER_SECOND": avg_tokens_per_second,
-        "THROUGHPUT_PER_SECOND": throughput_per_second,
-        "STD_DEVIATION": standard_deviation,
-    }
-
-
 def sharded_supported(model_id: str, config_dict: dict) -> bool:
     """Check if sharded is supported for this ``Model``"""
     model_type = config_dict.get("model_type", None)
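
The selection rule itself now lives only in the retained _more_performant (which takes the list layout); the removed _more_performant_benchmark above shows the same idea over dicts. Here is a toy restatement of that rule for the new list layout, with the _within_margins(MARGIN, 5, ...) check simplified to a plain relative-margin comparison (an assumption, since that helper's body is not shown here):

    # Toy restatement of the comparison rule, not the actual _more_performant body.
    def lower_latency_wins(best: list, candidate: list, margin: float = 0.05) -> bool:
        best_latency, candidate_latency = best[0], candidate[0]
        best_std, candidate_std = best[6], candidate[6]
        # Latencies within the margin: break the tie on standard deviation.
        if abs(candidate_latency - best_latency) <= margin * best_latency:
            return candidate_std <= best_std
        # Otherwise simply prefer the lower average latency.
        return candidate_latency <= best_latency

    best = [0.50, 2, None, 0.62, 28.0, 10.0, 0.04]
    candidate = [0.48, 4, None, 0.60, 30.0, 11.0, 0.02]
    print(lower_latency_wins(best, candidate))  # True: candidate becomes the new best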

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 2 additions & 6 deletions
@@ -36,7 +36,7 @@
 mock_tgi_most_performant_model_serving_properties = {
     "SAGEMAKER_PROGRAM": "inference.py",
     "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
-    "SM_NUM_GPUS": "4",
+    "SM_NUM_GPUS": "2",
 }
 mock_tgi_model_serving_properties = {
     "SAGEMAKER_PROGRAM": "inference.py",
@@ -530,8 +530,4 @@ def test_tune_for_djl_js_local_container_sharded_not_enabled(
         mock_pre_trained_model.return_value.env = mock_djl_model_serving_properties
 
         tuned_model = model.tune()
-        assert tuned_model.env == {
-            "SAGEMAKER_PROGRAM": "inference.py",
-            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
-            "OPTION_TENSOR_PARALLEL_DEGREE": "1",
-        }
+        assert tuned_model.env == mock_djl_most_performant_model_serving_properties
