Skip to content

Commit 4be1679

Browse files
author
Jonathan Makunga
committed
Refactoring
1 parent 1ec29e3 commit 4be1679

File tree

2 files changed

+18
-216
lines changed

2 files changed

+18
-216
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from sagemaker.model import Model
2323
from sagemaker import model_uris
2424
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
25-
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
25+
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees, \
26+
_get_admissible_dtypes
2627
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
2728
from sagemaker.serve.mode.function_pointers import Mode
2829
from sagemaker.serve.utils.exceptions import (
@@ -42,8 +43,7 @@
4243
_serial_benchmark,
4344
_concurrent_benchmark,
4445
_more_performant,
45-
_pretty_print_results_djl_js,
46-
_pretty_print_results_tgi_js
46+
_pretty_print_benchmark_results
4747
)
4848
from sagemaker.serve.utils.types import ModelServer
4949
from sagemaker.base_predictor import PredictorBase
@@ -254,169 +254,6 @@ def _build_for_tgi_jumpstart(self):
254254

255255
self.pysdk_model.env.update(env)
256256

257-
def _log_delete_me(self, data: any):
258-
"""Placeholder docstring"""
259-
logger.debug("*****************************************")
260-
logger.debug(data)
261-
logger.debug("*****************************************")
262-
263-
@_capture_telemetry("djl_jumpstart.tune")
264-
def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
265-
"""pass"""
266-
if self.mode != Mode.LOCAL_CONTAINER:
267-
logger.warning(
268-
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
269-
)
270-
return self.pysdk_model
271-
272-
initial_model_configuration = copy.deepcopy(self.pysdk_model.env)
273-
self._log_delete_me(f"initial_model_configuration: {initial_model_configuration}")
274-
275-
self._log_delete_me(f"self.js_model_config: {self.js_model_config}")
276-
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(self.js_model_config)
277-
self._log_delete_me(f"admissible_tensor_parallel_degrees: {admissible_tensor_parallel_degrees}")
278-
279-
benchmark_results = {}
280-
best_tuned_combination = None
281-
timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
282-
for tensor_parallel_degree in admissible_tensor_parallel_degrees:
283-
self._log_delete_me(f"tensor_parallel_degree: {tensor_parallel_degree}")
284-
if datetime.now() > timeout:
285-
logger.info("Max tuning duration reached. Tuning stopped.")
286-
break
287-
288-
self.pysdk_model.env.update({
289-
"OPTION_TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree)
290-
})
291-
292-
try:
293-
predictor = self.pysdk_model.deploy(
294-
model_data_download_timeout=max_tuning_duration
295-
)
296-
297-
avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
298-
predictor, self.schema_builder.sample_input
299-
)
300-
throughput_per_second, standard_deviation = _concurrent_benchmark(
301-
predictor, self.schema_builder.sample_input
302-
)
303-
304-
tested_env = self.pysdk_model.env.copy()
305-
logger.info(
306-
"Average latency: %s, throughput/s: %s for configuration: %s",
307-
avg_latency,
308-
throughput_per_second,
309-
tested_env,
310-
)
311-
benchmark_results[avg_latency] = [
312-
tested_env,
313-
p90,
314-
avg_tokens_per_second,
315-
throughput_per_second,
316-
standard_deviation,
317-
]
318-
319-
if not best_tuned_combination:
320-
best_tuned_combination = [
321-
avg_latency,
322-
tensor_parallel_degree,
323-
None,
324-
p90,
325-
avg_tokens_per_second,
326-
throughput_per_second,
327-
standard_deviation,
328-
]
329-
else:
330-
tuned_configuration = [
331-
avg_latency,
332-
tensor_parallel_degree,
333-
None,
334-
p90,
335-
avg_tokens_per_second,
336-
throughput_per_second,
337-
standard_deviation,
338-
]
339-
if _more_performant(best_tuned_combination, tuned_configuration):
340-
best_tuned_combination = tuned_configuration
341-
except LocalDeepPingException as e:
342-
logger.warning(
343-
"Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
344-
"Failed to invoke the model server: %s",
345-
tensor_parallel_degree,
346-
str(e),
347-
)
348-
break
349-
except LocalModelOutOfMemoryException as e:
350-
logger.warning(
351-
"Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
352-
"Out of memory when loading the model: %s",
353-
tensor_parallel_degree,
354-
str(e),
355-
)
356-
break
357-
except LocalModelInvocationException as e:
358-
logger.warning(
359-
"Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
360-
"Failed to invoke the model server: %s"
361-
"Please check that model server configurations are as expected "
362-
"(Ex. serialization, deserialization, content_type, accept).",
363-
tensor_parallel_degree,
364-
str(e),
365-
)
366-
break
367-
except LocalModelLoadException as e:
368-
logger.warning(
369-
"Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
370-
"Failed to load the model: %s.",
371-
tensor_parallel_degree,
372-
str(e),
373-
)
374-
break
375-
except SkipTuningComboException as e:
376-
logger.warning(
377-
"Deployment with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
378-
"was expected to be successful. However failed with: %s. "
379-
"Trying next combination.",
380-
tensor_parallel_degree,
381-
str(e),
382-
)
383-
break
384-
except Exception:
385-
logger.exception(
386-
"Deployment unsuccessful with OPTION_TENSOR_PARALLEL_DEGREE: %s. "
387-
"with uncovered exception",
388-
tensor_parallel_degree
389-
)
390-
break
391-
392-
if best_tuned_combination:
393-
self.pysdk_model.env.update({
394-
"OPTION_TENSOR_PARALLEL_DEGREE": str(best_tuned_combination[1])
395-
})
396-
397-
_pretty_print_results_djl_js(benchmark_results)
398-
logger.info(
399-
"Model Configuration: %s was most performant with avg latency: %s, "
400-
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
401-
"standard deviation of request %s",
402-
self.pysdk_model.env,
403-
best_tuned_combination[0],
404-
best_tuned_combination[3],
405-
best_tuned_combination[4],
406-
best_tuned_combination[5],
407-
best_tuned_combination[6],
408-
)
409-
else:
410-
self.pysdk_model.env.update(initial_model_configuration)
411-
logger.debug(
412-
"Failed to gather any tuning results. "
413-
"Please inspect the stack trace emitted from live logging for more details. "
414-
"Falling back to default model environment variable configurations: %s",
415-
self.pysdk_model.env,
416-
)
417-
418-
return self.pysdk_model
419-
420257
@_capture_telemetry("tgi_jumpstart.tune")
421258
def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
422259
"""Placeholder docstring"""
@@ -426,19 +263,13 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
426263
)
427264
return self.pysdk_model
428265

429-
initial_model_configuration = copy.deepcopy(self.pysdk_model.env)
430-
self._log_delete_me(f"initial_model_configuration: {initial_model_configuration}")
431-
432-
self._log_delete_me(f"self.js_model_config: {self.js_model_config}")
433-
266+
initial_env_vars = copy.deepcopy(self.pysdk_model.env)
434267
admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(self.js_model_config)
435-
self._log_delete_me(f"admissible_tensor_parallel_degrees: {admissible_tensor_parallel_degrees}")
436268

437269
benchmark_results = {}
438270
best_tuned_combination = None
439271
timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
440272
for tensor_parallel_degree in admissible_tensor_parallel_degrees:
441-
self._log_delete_me(f"tensor_parallel_degree: {tensor_parallel_degree}")
442273
if datetime.now() > timeout:
443274
logger.info("Max tuning duration reached. Tuning stopped.")
444275
break
@@ -554,7 +385,9 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
554385
"SM_NUM_GPUS": str(best_tuned_combination[1])
555386
})
556387

557-
_pretty_print_results_tgi_js(benchmark_results)
388+
_pretty_print_benchmark_results(benchmark_results, [
389+
"SM_NUM_GPUS"
390+
])
558391
logger.info(
559392
"Model Configuration: %s was most performant with avg latency: %s, "
560393
"p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
@@ -567,7 +400,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
567400
best_tuned_combination[6],
568401
)
569402
else:
570-
self.pysdk_model.env.update(initial_model_configuration)
403+
self.pysdk_model.env.update(initial_env_vars)
571404
logger.debug(
572405
"Failed to gather any tuning results. "
573406
"Please inspect the stack trace emitted from live logging for more details. "
@@ -612,6 +445,4 @@ def _build_for_jumpstart(self):
612445

613446
if self.model_server == ModelServer.TGI:
614447
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
615-
elif self.model_server == ModelServer.DJL_SERVING:
616-
self.pysdk_model.tune = self.tune_for_djl_jumpstart
617448
return self.pysdk_model

src/sagemaker/serve/utils/tuning.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -98,47 +98,16 @@ def _pretty_print_results_tgi(results: dict):
9898
)
9999

100100

101-
def _pretty_print_results_tgi_js(results: dict):
101+
def _pretty_print_benchmark_results(results: dict, model_env_vars=None):
102102
"""Placeholder docstring"""
103-
avg_latencies = []
104-
sm_num_gpus = []
105-
p90s = []
106-
avg_tokens_per_seconds = []
107-
throughput_per_seconds = []
108-
standard_deviations = []
109-
ordered = collections.OrderedDict(sorted(results.items()))
110-
111-
for key, value in ordered.items():
112-
avg_latencies.append(key)
113-
sm_num_gpus.append(value[0]["SM_NUM_GPUS"])
114-
p90s.append(value[1])
115-
avg_tokens_per_seconds.append(value[2])
116-
throughput_per_seconds.append(value[3])
117-
standard_deviations.append(value[4])
118-
119-
df = pd.DataFrame(
120-
{
121-
"AverageLatency (Serial)": avg_latencies,
122-
"P90_Latency (Serial)": p90s,
123-
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
124-
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
125-
"StandardDeviationResponse (Concurrent)": standard_deviations,
126-
"SM_NUM_GPUS": sm_num_gpus,
127-
}
128-
)
129-
logger.info(
130-
"\n================================================================== Benchmark "
131-
"Results ==================================================================\n%s"
132-
"\n============================================================================"
133-
"===========================================================================\n",
134-
df.to_string(),
135-
)
103+
if model_env_vars is None:
104+
model_env_vars = []
136105

106+
__env_var_data = {}
107+
for e in model_env_vars:
108+
__env_var_data[e] = []
137109

138-
def _pretty_print_results_djl_js(results: dict):
139-
"""Placeholder docstring"""
140110
avg_latencies = []
141-
option_tp_degree = []
142111
p90s = []
143112
avg_tokens_per_seconds = []
144113
throughput_per_seconds = []
@@ -147,20 +116,22 @@ def _pretty_print_results_djl_js(results: dict):
147116

148117
for key, value in ordered.items():
149118
avg_latencies.append(key)
150-
option_tp_degree.append(value[0]["OPTION_TENSOR_PARALLEL_DEGREE"])
151119
p90s.append(value[1])
152120
avg_tokens_per_seconds.append(value[2])
153121
throughput_per_seconds.append(value[3])
154122
standard_deviations.append(value[4])
155123

124+
for k in __env_var_data:
125+
__env_var_data[k].append(value[0][k])
126+
156127
df = pd.DataFrame(
157128
{
158129
"AverageLatency (Serial)": avg_latencies,
159130
"P90_Latency (Serial)": p90s,
160131
"AverageTokensPerSecond (Serial)": avg_tokens_per_seconds,
161132
"ThroughputPerSecond (Concurrent)": throughput_per_seconds,
162133
"StandardDeviationResponse (Concurrent)": standard_deviations,
163-
"OPTION_TENSOR_PARALLEL_DEGREE": option_tp_degree,
134+
**__env_var_data
164135
}
165136
)
166137
logger.info(

0 commit comments

Comments (0)